Better repr of aval when shardings are present

Example, for an array of shape (8, 2) with dtype float32:

```
P('x', 'y')         -- float32[8@x,2@y]
P('x', None)        -- float32[8@x,2]
P(('x', 'y'), None) -- float32[8@xy,2]
P(None, None)       -- float32[8,2]
```

PiperOrigin-RevId: 684996577
Author: Yash Katariya (2024-10-11 16:47:43 -07:00), committed by jax authors
Commit: 89fcd9f1f1 (parent: 18bc354305)
4 changed files with 39 additions and 5 deletions
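A minimal sketch of the new repr in action, mirroring the `test_aval_repr` test added at the bottom of this change. It assumes the experimental sharding-in-types mode is enabled; without it, `ShapedArray` does not keep the sharding and `str_short()` prints the plain shape.

```
# Minimal sketch mirroring test_aval_repr below; assumes the experimental
# sharding-in-types mode is enabled so ShapedArray keeps its `sharding`.
import numpy as np
from jax._src import core, test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_mesh((2, 2), ('x', 'y'))
aval = core.ShapedArray((8, 2), np.float32,
                        sharding=NamedSharding(mesh, P('x', 'y')))
print(aval.str_short())   # float32[8@x,2@y]

aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
print(aval.str_short())   # float32[8@xy,2]
```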


@@ -1030,7 +1030,7 @@ def _get_aval_array(self):
   if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
     return self.aval.update(sharding=NamedSharding(
         self.sharding.mesh.abstract_mesh,
-        self.sharding.normalized_spec(self.ndim)))
+        self.sharding._normalized_spec(self.ndim)))
   else:
     return self.aval
 api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array


@@ -1755,6 +1755,8 @@ class ShapedArray(UnshapedArray):
     self.dtype = _dtype_object(dtype)
     self.weak_type = weak_type
     if config.sharding_in_types.value:
+      if sharding is not None:
+        assert len(sharding.spec) == len(self.shape)
       self.sharding = sharding
 
   def update(self, shape=None, dtype=None, weak_type=None, sharding=None):
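The new guard simply requires one spec entry per array dimension. A hedged illustration (again assuming the sharding-in-types mode is active, as in the `ShardingInTypesTest` class below):

```
# Hedged illustration of the new rank check; it only takes effect when the
# sharding-in-types mode is active so ShapedArray inspects `sharding`.
import numpy as np
from jax._src import core, test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_mesh((2, 2), ('x', 'y'))
core.ShapedArray((8, 2), np.float32,
                 sharding=NamedSharding(mesh, P('x', 'y')))  # ok: len(spec) == ndim == 2
core.ShapedArray((8, 2), np.float32,
                 sharding=NamedSharding(mesh, P('x')))       # AssertionError: len(spec) == 1
```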
@@ -1805,12 +1807,14 @@ class ShapedArray(UnshapedArray):
       raise TypeError(self, other)
 
   def str_short(self, short_dtypes=False):
-    dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
+    dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
+              self.dtype.name)
     dt_str = dt_str.replace('void', 'float0')
-    shapestr = ','.join(map(str, self.shape))
     if hasattr(self, 'sharding'):
-      return f'{dt_str}[{shapestr}]({self.sharding})'
+      shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
+      return f'{dt_str}[{shapestr}]'
     else:
+      shapestr = ','.join(map(str, self.shape))
       return f'{dt_str}[{shapestr}]'
 
   def _len(self, ignored_tracer):
@@ -1820,6 +1824,17 @@ class ShapedArray(UnshapedArray):
       raise TypeError("len() of unsized object") from err  # same as numpy error
 
+
+def _get_shape_sharding_str(shape, spec):
+  for s1, s2 in zip(shape, spec):
+    if s2 is None:
+      yield f"{s1}"
+    elif isinstance(s2, tuple):
+      ss = ''.join(s for s in s2)
+      yield f"{s1}@{ss}"
+    else:
+      yield f"{s1}@{s2}"
+
 def _forward_to_value(self, fun, ignored_tracer, *args):
   return fun(self.val, *args)
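For reference, a small usage sketch of the helper: it is a generator over `(dim size, spec entry)` pairs, so `str_short` joins its output with commas. The function body is copied verbatim from the hunk above only so the snippet runs standalone.

```
# Usage sketch; _get_shape_sharding_str is copied from the hunk above so the
# snippet runs on its own.
from jax.sharding import PartitionSpec as P

def _get_shape_sharding_str(shape, spec):
  for s1, s2 in zip(shape, spec):
    if s2 is None:
      yield f"{s1}"
    elif isinstance(s2, tuple):
      ss = ''.join(s for s in s2)
      yield f"{s1}@{ss}"
    else:
      yield f"{s1}@{s2}"

print(','.join(_get_shape_sharding_str((8, 2), P('x', 'y'))))          # 8@x,2@y
print(','.join(_get_shape_sharding_str((8, 2), P(('x', 'y'), None))))  # 8@xy,2
print(','.join(_get_shape_sharding_str((8, 2), P(None, None))))        # 8,2
```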


@@ -307,7 +307,7 @@ class NamedSharding(sharding.Sharding):
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
-  def normalized_spec(self, ndim: int) -> PartitionSpec:
+  def _normalized_spec(self, ndim: int) -> PartitionSpec:
     out = []  # type: ignore
     for p in self._parsed_pspec:
       if p is None:
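The rename to a private helper goes hand in hand with the new rank assertion in `ShapedArray`: `_get_aval_array` passes `self.ndim` so the spec attached to the aval has exactly one entry per dimension. A hedged sketch of the expected behaviour; only the first lines of the method are visible in this hunk, so the padding detail is an assumption:

```
# Hedged sketch of NamedSharding._normalized_spec. Assumption: it expands the
# user-provided spec to exactly `ndim` entries, filling missing trailing
# dimensions with None; only the start of the method appears in the hunk above.
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x'))   # spec only covers the first dimension
print(s._normalized_spec(2))      # expected: PartitionSpec('x', None)
```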


@@ -4739,6 +4739,25 @@ class ShardingInTypesTest(jtu.JaxTestCase):
                                 ' have the consistent sharding'):
       jnp.einsum('abc,acz->abz', arr1, arr2)
 
+  def test_aval_repr(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    aval = core.ShapedArray((8, 2), np.float32,
+                            sharding=NamedSharding(mesh, P('x', 'y')))
+    self.assertEqual(aval.str_short(), 'float32[8@x,2@y]')
+
+    aval = aval.update(sharding=NamedSharding(mesh, P('x', None)))
+    self.assertEqual(aval.str_short(), 'float32[8@x,2]')
+
+    aval = aval.update(sharding=NamedSharding(mesh, P(None, 'y')))
+    self.assertEqual(aval.str_short(), 'float32[8,2@y]')
+
+    aval = aval.update(sharding=NamedSharding(mesh, P(None, None)))
+    self.assertEqual(aval.str_short(), 'float32[8,2]')
+
+    aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
+    self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
+
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):