add pmap axes hints

This commit is contained in:
Owen Lockwood 2025-03-05 12:14:24 -08:00
parent 8df00e2666
commit 3e4dc0d490

View File

@ -1135,8 +1135,8 @@ def pmap(
fun: Callable,
axis_name: AxisName | None = None,
*,
in_axes=0,
out_axes=0,
in_axes: int | None | Sequence[Any] = 0,
out_axes: Any = 0,
static_broadcasted_argnums: int | Iterable[int] = (),
devices: Sequence[xc.Device] | None = None, # noqa: F811
backend: str | None = None,