#sdy add repr for Sdy ArraySharding and DimSharding

PiperOrigin-RevId: 713422071
This commit is contained in:
Bart Chrzaszcz 2025-01-08 14:41:02 -08:00 committed by jax authors
parent 196eec8296
commit cbcc883ea3
2 changed files with 40 additions and 0 deletions

View File

@ -125,6 +125,17 @@ class SdyDimSharding:
is_closed=self.is_closed,
priority=self.priority)
def __repr__(self):
return f'SdyDimSharding({self._custom_repr()})'
def _custom_repr(self):
axes_repr = ', '.join(f"'{a}'" for a in self.axes)
open_repr = ''
if not self.is_closed:
open_repr = ', ?' if self.axes else '?'
priority_repr = '' if self.priority is None else f'p{self.priority}'
return f'{{{axes_repr}{open_repr}}}{priority_repr}'
@dataclasses.dataclass
class SdyArraySharding:
@ -146,6 +157,13 @@ class SdyArraySharding:
mesh_attr,
[dim_sharding.build() for dim_sharding in self.dimension_shardings])
def __repr__(self):
dim_sharding_repr = ', '.join(
d._custom_repr() for d in self.dimension_shardings)
device_id_repr = (f', device_ids={self.logical_device_ids}'
if self.logical_device_ids is not None else '')
return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"
@util.cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(

View File

@ -6632,6 +6632,28 @@ class ShardyTest(jtu.JaxTestCase):
lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text()
self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str)
def test_array_sharding_repr_with_priority(self):
    """repr() renders closed/open dims plus the priority suffix."""
    dim0 = sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True)
    dim1 = sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)
    sharding = sharding_impls.SdyArraySharding(
        mesh_shape=(('data', 4), ('model', 8), ('expert', 2)),
        dimension_shardings=[dim0, dim1])
    expected = "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])"
    self.assertEqual(repr(sharding), expected)
def test_array_sharding_repr_with_logical_ids(self):
    """repr() appends device_ids when logical device ids are customized."""
    abstract_mesh = jax.sharding.AbstractMesh((('x', 4), ('y', 8), ('z', 2)))
    spec = P(('x', 'y'), 'z', P.UNCONSTRAINED, None)
    ns = NamedSharding(abstract_mesh, spec,
                       _logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
    expected = ("SdyArraySharding([{'x', 'y'}, {'z'}, {?}, {}], "
                "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])")
    self.assertEqual(repr(ns._to_sdy_sharding(4)), expected)
def test_dimension_sharding_repr(self):
    """repr() of a single open dim sharding carrying a priority."""
    dim_sharding = sharding_impls.SdyDimSharding(
        axes=['data', 'model'], is_closed=False, priority=2)
    expected = "SdyDimSharding({'data', 'model', ?}p2)"
    self.assertEqual(repr(dim_sharding), expected)
# Entry point: run the suite with absltest using JAX's test loader when
# this file is executed directly as a script.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())