Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes

PiperOrigin-RevId: 685215764
This commit is contained in:
Yash Katariya 2024-10-12 09:55:23 -07:00 committed by jax authors
parent 5b8775dc2f
commit 8139c531a3
2 changed files with 13 additions and 13 deletions

View File

@ -1829,8 +1829,8 @@ def _get_shape_sharding_str(shape, spec):
if s2 is None:
yield f"{s1}"
elif isinstance(s2, tuple):
ss = ''.join(s for s in s2)
yield f"{s1}@{ss}"
ss = ','.join(s for s in s2)
yield f"{s1}@({ss})"
else:
yield f"{s1}@{s2}"

View File

@ -4740,23 +4740,23 @@ class ShardingInTypesTest(jtu.JaxTestCase):
jnp.einsum('abc,acz->abz', arr1, arr2)
def test_aval_repr(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('model', 'data'))
aval = core.ShapedArray((8, 2), np.float32,
sharding=NamedSharding(mesh, P('x', 'y')))
self.assertEqual(aval.str_short(), 'float32[8@x,2@y]')
aval = core.ShapedArray((128, 64), np.float32,
sharding=NamedSharding(mesh, P('model', 'data')))
self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
aval = aval.update(sharding=NamedSharding(mesh, P('x', None)))
self.assertEqual(aval.str_short(), 'float32[8@x,2]')
aval = aval.update(sharding=NamedSharding(mesh, P('model', None)))
self.assertEqual(aval.str_short(), 'float32[128@model,64]')
aval = aval.update(sharding=NamedSharding(mesh, P(None, 'y')))
self.assertEqual(aval.str_short(), 'float32[8,2@y]')
aval = aval.update(sharding=NamedSharding(mesh, P(None, 'data')))
self.assertEqual(aval.str_short(), 'float32[128,64@data]')
aval = aval.update(sharding=NamedSharding(mesh, P(None, None)))
self.assertEqual(aval.str_short(), 'float32[8,2]')
self.assertEqual(aval.str_short(), 'float32[128,64]')
aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), None)))
self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]')
@parameterized.named_parameters(
('all', None,P('x', 'y'), P()),