mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
#sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
This commit is contained in:
parent
196eec8296
commit
cbcc883ea3
@ -125,6 +125,17 @@ class SdyDimSharding:
|
||||
is_closed=self.is_closed,
|
||||
priority=self.priority)
|
||||
|
||||
def __repr__(self):
|
||||
return f'SdyDimSharding({self._custom_repr()})'
|
||||
|
||||
def _custom_repr(self):
|
||||
axes_repr = ', '.join(f"'{a}'" for a in self.axes)
|
||||
open_repr = ''
|
||||
if not self.is_closed:
|
||||
open_repr = ', ?' if self.axes else '?'
|
||||
priority_repr = '' if self.priority is None else f'p{self.priority}'
|
||||
return f'{{{axes_repr}{open_repr}}}{priority_repr}'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SdyArraySharding:
|
||||
@ -146,6 +157,13 @@ class SdyArraySharding:
|
||||
mesh_attr,
|
||||
[dim_sharding.build() for dim_sharding in self.dimension_shardings])
|
||||
|
||||
def __repr__(self):
|
||||
dim_sharding_repr = ', '.join(
|
||||
d._custom_repr() for d in self.dimension_shardings)
|
||||
device_id_repr = (f', device_ids={self.logical_device_ids}'
|
||||
if self.logical_device_ids is not None else '')
|
||||
return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"
|
||||
|
||||
|
||||
@util.cache(max_size=4096, trace_context_in_key=False)
|
||||
def named_sharding_to_xla_hlo_sharding(
|
||||
|
@ -6632,6 +6632,28 @@ class ShardyTest(jtu.JaxTestCase):
|
||||
lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text()
|
||||
self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str)
|
||||
|
||||
def test_array_sharding_repr_with_priority(self):
|
||||
sharding = sharding_impls.SdyArraySharding(
|
||||
mesh_shape=(('data', 4), ('model', 8), ('expert', 2)),
|
||||
dimension_shardings=[
|
||||
sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True),
|
||||
sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)])
|
||||
self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])")
|
||||
|
||||
def test_array_sharding_repr_with_logical_ids(self):
|
||||
abstract_mesh = jax.sharding.AbstractMesh((('x', 4), ('y', 8), ('z', 2)))
|
||||
ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None),
|
||||
_logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
|
||||
self.assertEqual(repr(ns._to_sdy_sharding(4)),
|
||||
"SdyArraySharding([{'x', 'y'}, {'z'}, {?}, {}], "
|
||||
"device_ids=[4, 5, 6, 7, 0, 1, 2, 3])")
|
||||
|
||||
def test_dimension_sharding_repr(self):
|
||||
dim_sharding = sharding_impls.SdyDimSharding(
|
||||
axes=['data', 'model'], is_closed=False, priority=2)
|
||||
self.assertEqual(repr(dim_sharding),
|
||||
"SdyDimSharding({'data', 'model', ?}p2)")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user