mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add pmap axes hints
This commit is contained in:
parent
8df00e2666
commit
3e4dc0d490
@ -1135,8 +1135,8 @@ def pmap(
|
||||
fun: Callable,
|
||||
axis_name: AxisName | None = None,
|
||||
*,
|
||||
in_axes=0,
|
||||
out_axes=0,
|
||||
in_axes: int | None | Sequence[Any] = 0,
|
||||
out_axes: Any = 0,
|
||||
static_broadcasted_argnums: int | Iterable[int] = (),
|
||||
devices: Sequence[xc.Device] | None = None, # noqa: F811
|
||||
backend: str | None = None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user