Fixes the random key sharding in shard_map.

This commit is contained in:
Yunlong Liu 2024-12-29 18:29:28 +00:00
parent 879fa12d90
commit 97b1faacdd
2 changed files with 20 additions and 1 deletions

View File

@ -724,7 +724,6 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
aval_in, aval_out, x):
if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1:
return x
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
@ -734,6 +733,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
unspecified = set(range(aval_in.ndim)) if auto else set()
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
unspecified_dims=unspecified)
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
@ -746,6 +746,8 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
ns = sharding_impls.physical_sharding(aval_out, ns)
aval_out = core.physical_aval(aval_out)
unspecified = set(range(aval_out.ndim)) if auto else set()
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
aval_in = core.physical_aval(aval_in)
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()

View File

@ -2207,6 +2207,23 @@ class ShardMapTest(jtu.JaxTestCase):
#
# f(x) # don't crash
def test_partial_auto_of_random_keys(self):
if config.use_shardy_partitioner.value:
self.skipTest('Shardy does not support full-to-shard.')
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
keys = jax.random.split(jax.random.key(0), 8)
@jax.jit
def f(x):
return shard_map(lambda k: k,
mesh, in_specs=P('i'), out_specs=P('i'),
check_rep=False, auto=frozenset({'j'}))(keys)
y = f(keys) # don't crash
self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
check_dtypes=False)
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
# https://github.com/jax-ml/jax/pull/21032
mesh = jtu.create_mesh((4, 2), ('i', 'j'))