mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
If cur_mesh is empty and AxisTypes of Mesh passed to shmap are Explicit, then treat the axes mentioned in auto
as explicit too. In other words, "auto" really means "don't convert to manual", ie leave the listed mesh axes as they are, whether explicit or auto type
PiperOrigin-RevId: 728942780
This commit is contained in:
parent
8305803b76
commit
b6b319cd06
@ -484,17 +484,15 @@ def _as_manual_mesh(mesh, auto: frozenset):
|
||||
manual_axes = tuple(set(mesh.axis_names) - auto)
|
||||
cur_mesh = get_abstract_mesh()
|
||||
if cur_mesh.empty:
|
||||
auto_axes = tuple(auto)
|
||||
explicit_axes = ()
|
||||
else:
|
||||
explicit_axes, auto_axes = [], [] # type: ignore
|
||||
for a in auto:
|
||||
if cur_mesh._name_to_type[a] == AxisTypes.Auto:
|
||||
auto_axes.append(a)
|
||||
else:
|
||||
assert cur_mesh._name_to_type[a] == AxisTypes.Explicit
|
||||
explicit_axes.append(a)
|
||||
explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore
|
||||
cur_mesh = mesh
|
||||
explicit_axes, auto_axes = [], [] # type: ignore
|
||||
for a in auto:
|
||||
if cur_mesh._name_to_type[a] == AxisTypes.Auto:
|
||||
auto_axes.append(a)
|
||||
else:
|
||||
assert cur_mesh._name_to_type[a] == AxisTypes.Explicit
|
||||
explicit_axes.append(a)
|
||||
explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore
|
||||
return AbstractMesh(
|
||||
mesh.shape_tuple,
|
||||
axis_types={
|
||||
|
@ -1921,6 +1921,34 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
)
|
||||
self.assertAllClose(v * v, f(v), check_dtypes=False)
|
||||
|
||||
def test_partial_auto_explicit_no_use_mesh(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'),
|
||||
axis_types={AxisTypes.Explicit: ('i', 'j')})
|
||||
|
||||
def g(x):
|
||||
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
|
||||
{AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)})
|
||||
self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
|
||||
out = x * x
|
||||
self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
|
||||
return out
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
x = shard_map(g, mesh,
|
||||
in_specs=P('i', None),
|
||||
out_specs=P('i', None),
|
||||
auto=frozenset({'j'}))(x)
|
||||
self.assertEqual(x.aval.sharding.spec, P('i', 'j'))
|
||||
return x
|
||||
|
||||
v = jnp.arange(32.).reshape(4, 8)
|
||||
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
|
||||
|
||||
out = f(v)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
|
||||
self.assertAllClose(v * v, out, check_dtypes=False)
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('i', 'j'))
|
||||
def test_partial_auto_explicit(self, mesh):
|
||||
def g(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user