mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
This commit is contained in:
parent
5b8775dc2f
commit
8139c531a3
@ -1829,8 +1829,8 @@ def _get_shape_sharding_str(shape, spec):
|
||||
if s2 is None:
|
||||
yield f"{s1}"
|
||||
elif isinstance(s2, tuple):
|
||||
ss = ''.join(s for s in s2)
|
||||
yield f"{s1}@{ss}"
|
||||
ss = ','.join(s for s in s2)
|
||||
yield f"{s1}@({ss})"
|
||||
else:
|
||||
yield f"{s1}@{s2}"
|
||||
|
||||
|
@ -4740,23 +4740,23 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
jnp.einsum('abc,acz->abz', arr1, arr2)
|
||||
|
||||
def test_aval_repr(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('model', 'data'))
|
||||
|
||||
aval = core.ShapedArray((8, 2), np.float32,
|
||||
sharding=NamedSharding(mesh, P('x', 'y')))
|
||||
self.assertEqual(aval.str_short(), 'float32[8@x,2@y]')
|
||||
aval = core.ShapedArray((128, 64), np.float32,
|
||||
sharding=NamedSharding(mesh, P('model', 'data')))
|
||||
self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P('x', None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[8@x,2]')
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P('model', None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[128@model,64]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(None, 'y')))
|
||||
self.assertEqual(aval.str_short(), 'float32[8,2@y]')
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(None, 'data')))
|
||||
self.assertEqual(aval.str_short(), 'float32[128,64@data]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(None, None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[8,2]')
|
||||
self.assertEqual(aval.str_short(), 'float32[128,64]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]')
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('all', None,P('x', 'y'), P()),
|
||||
|
Loading…
x
Reference in New Issue
Block a user