mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas] Fix shard_axis in dma_start interpret mode rule.
PiperOrigin-RevId: 703192497
This commit is contained in:
parent
7e6620a577
commit
259194a69f
@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
|
|||||||
if device_id_len > 1 or len(nonempty_axes) > 1:
|
if device_id_len > 1 or len(nonempty_axes) > 1:
|
||||||
raise NotImplementedError("Meshes with more than 1 named dimension not "
|
raise NotImplementedError("Meshes with more than 1 named dimension not "
|
||||||
"implemented in dma_start_p")
|
"implemented in dma_start_p")
|
||||||
shard_axis = nonempty_axes[0].name
|
shard_axis = nonempty_axes[0]
|
||||||
my_axis = jax.lax.axis_index(shard_axis)
|
my_axis = jax.lax.axis_index(shard_axis)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown device_id_type: {device_id_type}")
|
raise ValueError(f"Unknown device_id_type: {device_id_type}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user