If cur_mesh is empty and AxisTypes of Mesh passed to shmap are Explicit, then treat the axes mentioned in auto as explicit too. In other words, "auto" really means "don't convert to manual", ie leave the listed mesh axes as they are, whether explicit or auto type

PiperOrigin-RevId: 728942780
This commit is contained in:
Yash Katariya 2025-02-19 21:25:22 -08:00 committed by jax authors
parent 8305803b76
commit b6b319cd06
2 changed files with 37 additions and 11 deletions

View File

@ -484,17 +484,15 @@ def _as_manual_mesh(mesh, auto: frozenset):
manual_axes = tuple(set(mesh.axis_names) - auto)
cur_mesh = get_abstract_mesh()
if cur_mesh.empty:
auto_axes = tuple(auto)
explicit_axes = ()
else:
explicit_axes, auto_axes = [], [] # type: ignore
for a in auto:
if cur_mesh._name_to_type[a] == AxisTypes.Auto:
auto_axes.append(a)
else:
assert cur_mesh._name_to_type[a] == AxisTypes.Explicit
explicit_axes.append(a)
explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore
cur_mesh = mesh
explicit_axes, auto_axes = [], [] # type: ignore
for a in auto:
if cur_mesh._name_to_type[a] == AxisTypes.Auto:
auto_axes.append(a)
else:
assert cur_mesh._name_to_type[a] == AxisTypes.Explicit
explicit_axes.append(a)
explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore
return AbstractMesh(
mesh.shape_tuple,
axis_types={

View File

@ -1921,6 +1921,34 @@ class ShardMapTest(jtu.JaxTestCase):
)
self.assertAllClose(v * v, f(v), check_dtypes=False)
def test_partial_auto_explicit_no_use_mesh(self):
mesh = jtu.create_mesh((2, 2), ('i', 'j'),
axis_types={AxisTypes.Explicit: ('i', 'j')})
def g(x):
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
{AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)})
self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
out = x * x
self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
return out
@jax.jit
def f(x):
x = shard_map(g, mesh,
in_specs=P('i', None),
out_specs=P('i', None),
auto=frozenset({'j'}))(x)
self.assertEqual(x.aval.sharding.spec, P('i', 'j'))
return x
v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
out = f(v)
self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
self.assertAllClose(v * v, out, check_dtypes=False)
@jtu.with_user_mesh((2, 2), ('i', 'j'))
def test_partial_auto_explicit(self, mesh):
def g(x):