mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32 ``` P('x', 'y') -- float32[8@x,2@y] P('x', None) -- float32[8@x,2] P(('x', 'y'), None) -- float32[8@xy,2] P(None, None) -- float32[8, 2] ``` PiperOrigin-RevId: 684996577
This commit is contained in:
parent
18bc354305
commit
89fcd9f1f1
@ -1030,7 +1030,7 @@ def _get_aval_array(self):
|
||||
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
|
||||
return self.aval.update(sharding=NamedSharding(
|
||||
self.sharding.mesh.abstract_mesh,
|
||||
self.sharding.normalized_spec(self.ndim)))
|
||||
self.sharding._normalized_spec(self.ndim)))
|
||||
else:
|
||||
return self.aval
|
||||
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
|
||||
|
@ -1755,6 +1755,8 @@ class ShapedArray(UnshapedArray):
|
||||
self.dtype = _dtype_object(dtype)
|
||||
self.weak_type = weak_type
|
||||
if config.sharding_in_types.value:
|
||||
if sharding is not None:
|
||||
assert len(sharding.spec) == len(self.shape)
|
||||
self.sharding = sharding
|
||||
|
||||
def update(self, shape=None, dtype=None, weak_type=None, sharding=None):
|
||||
@ -1805,12 +1807,14 @@ class ShapedArray(UnshapedArray):
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self, short_dtypes=False):
|
||||
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
|
||||
self.dtype.name)
|
||||
dt_str = dt_str.replace('void', 'float0')
|
||||
shapestr = ','.join(map(str, self.shape))
|
||||
if hasattr(self, 'sharding'):
|
||||
return f'{dt_str}[{shapestr}]({self.sharding})'
|
||||
shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
else:
|
||||
shapestr = ','.join(map(str, self.shape))
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
|
||||
def _len(self, ignored_tracer):
|
||||
@ -1820,6 +1824,17 @@ class ShapedArray(UnshapedArray):
|
||||
raise TypeError("len() of unsized object") from err # same as numpy error
|
||||
|
||||
|
||||
def _get_shape_sharding_str(shape, spec):
|
||||
for s1, s2 in zip(shape, spec):
|
||||
if s2 is None:
|
||||
yield f"{s1}"
|
||||
elif isinstance(s2, tuple):
|
||||
ss = ''.join(s for s in s2)
|
||||
yield f"{s1}@{ss}"
|
||||
else:
|
||||
yield f"{s1}@{s2}"
|
||||
|
||||
|
||||
def _forward_to_value(self, fun, ignored_tracer, *args):
|
||||
return fun(self.val, *args)
|
||||
|
||||
|
@ -307,7 +307,7 @@ class NamedSharding(sharding.Sharding):
|
||||
def with_memory_kind(self, kind: str) -> NamedSharding:
|
||||
return NamedSharding(self.mesh, self.spec, memory_kind=kind)
|
||||
|
||||
def normalized_spec(self, ndim: int) -> PartitionSpec:
|
||||
def _normalized_spec(self, ndim: int) -> PartitionSpec:
|
||||
out = [] # type: ignore
|
||||
for p in self._parsed_pspec:
|
||||
if p is None:
|
||||
|
@ -4739,6 +4739,25 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
' have the consistent sharding'):
|
||||
jnp.einsum('abc,acz->abz', arr1, arr2)
|
||||
|
||||
def test_aval_repr(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
aval = core.ShapedArray((8, 2), np.float32,
|
||||
sharding=NamedSharding(mesh, P('x', 'y')))
|
||||
self.assertEqual(aval.str_short(), 'float32[8@x,2@y]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P('x', None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[8@x,2]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(None, 'y')))
|
||||
self.assertEqual(aval.str_short(), 'float32[8,2@y]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(None, None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[8,2]')
|
||||
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user