diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 133951dc8..3d17c410d 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -484,17 +484,15 @@ def _as_manual_mesh(mesh, auto: frozenset): manual_axes = tuple(set(mesh.axis_names) - auto) cur_mesh = get_abstract_mesh() if cur_mesh.empty: - auto_axes = tuple(auto) - explicit_axes = () - else: - explicit_axes, auto_axes = [], [] # type: ignore - for a in auto: - if cur_mesh._name_to_type[a] == AxisTypes.Auto: - auto_axes.append(a) - else: - assert cur_mesh._name_to_type[a] == AxisTypes.Explicit - explicit_axes.append(a) - explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore + cur_mesh = mesh + explicit_axes, auto_axes = [], [] # type: ignore + for a in auto: + if cur_mesh._name_to_type[a] == AxisTypes.Auto: + auto_axes.append(a) + else: + assert cur_mesh._name_to_type[a] == AxisTypes.Explicit + explicit_axes.append(a) + explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore return AbstractMesh( mesh.shape_tuple, axis_types={ diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 4499fa4b1..863ee9008 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1921,6 +1921,34 @@ class ShardMapTest(jtu.JaxTestCase): ) self.assertAllClose(v * v, f(v), check_dtypes=False) + def test_partial_auto_explicit_no_use_mesh(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j'), + axis_types={AxisTypes.Explicit: ('i', 'j')}) + + def g(x): + self.assertDictEqual(x.aval.sharding.mesh.axis_types, + {AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)}) + self.assertEqual(x.aval.sharding.spec, P(None, 'j')) + out = x * x + self.assertEqual(out.aval.sharding.spec, P(None, 'j')) + return out + + @jax.jit + def f(x): + x = shard_map(g, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + auto=frozenset({'j'}))(x) + self.assertEqual(x.aval.sharding.spec, P('i', 'j')) + return x + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + + out = f(v) + self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j'))) + self.assertAllClose(v * v, out, check_dtypes=False) + @jtu.with_user_mesh((2, 2), ('i', 'j')) def test_partial_auto_explicit(self, mesh): def g(x):