mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas] Fix shard_axis in dma_start interpret mode rule.
PiperOrigin-RevId: 703192497
This commit is contained in:
parent
7e6620a577
commit
259194a69f
@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
|
||||
if device_id_len > 1 or len(nonempty_axes) > 1:
|
||||
raise NotImplementedError("Meshes with more than 1 named dimension not "
|
||||
"implemented in dma_start_p")
|
||||
shard_axis = nonempty_axes[0].name
|
||||
shard_axis = nonempty_axes[0]
|
||||
my_axis = jax.lax.axis_index(shard_axis)
|
||||
else:
|
||||
raise ValueError(f"Unknown device_id_type: {device_id_type}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user