[Pallas] Fix shard_axis in dma_start interpret mode rule.

PiperOrigin-RevId: 703192497
This commit is contained in:
Justin Fu 2024-12-05 11:44:15 -08:00 committed by Peter Hawkins
parent 7e6620a577
commit 259194a69f

View File

@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
if device_id_len > 1 or len(nonempty_axes) > 1:
raise NotImplementedError("Meshes with more than 1 named dimension not "
"implemented in dma_start_p")
shard_axis = nonempty_axes[0].name
shard_axis = nonempty_axes[0]
my_axis = jax.lax.axis_index(shard_axis)
else:
raise ValueError(f"Unknown device_id_type: {device_id_type}")