diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index f818d6bf1..afae3d9bc 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1440,7 +1440,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, mesh = None if shardy_enabled: sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule)) - mesh = mesh_lib.AbstractMesh(tuple(sdy_mesh_axes)) if sdy_mesh_axes else None + mesh = mesh_lib.AbstractMesh( + *list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.ShardingContext): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index e614fba37..4c456e969 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -416,7 +416,8 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): @functools.cached_property def abstract_mesh(self): - return AbstractMesh(self.shape_tuple, axis_types=self._axis_types) + return AbstractMesh(self.axis_sizes, self.axis_names, + axis_types=self._axis_types) EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -441,15 +442,12 @@ class AbstractMesh(_BaseMesh): details. """ - def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *, - axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None): - self.shape_tuple = shape_tuple - if self.shape_tuple: - self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) - else: - self._axis_names, self._axis_sizes = (), () - self._size = math.prod(self._axis_sizes) if self._axis_sizes else 0 - self._axis_types = _normalize_axis_types(self._axis_names, axis_types) + def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], + *, axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None): + self.axis_sizes = axis_sizes + self.axis_names = axis_names + self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0 + self._axis_types = _normalize_axis_types(self.axis_names, axis_types) def __hash__(self): return hash((self.shape_tuple, self._axis_types)) @@ -468,14 +466,6 @@ class AbstractMesh(_BaseMesh): atr = f", axis_types={self._axis_types}" return f"AbstractMesh({mesh_repr}{atr})" - @property - def axis_names(self): - return self._axis_names - - @property - def axis_sizes(self) -> tuple[int, ...]: - return self._axis_sizes - @property def size(self): return self._size @@ -484,6 +474,12 @@ class AbstractMesh(_BaseMesh): def shape(self): return collections.OrderedDict(self.shape_tuple) + @functools.cached_property + def shape_tuple(self): + return tuple( + (name, size) + for name, size in safe_zip(self.axis_names, self.axis_sizes)) + @property def _internal_device_list(self): return None @@ -499,7 +495,8 @@ class AbstractMesh(_BaseMesh): def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisTypes]): new_axis_types = tuple(name_to_type[n] if n in name_to_type else a for n, a in zip(self.axis_names, self._axis_types)) - return AbstractMesh(self.shape_tuple, axis_types=new_axis_types) + return AbstractMesh(self.axis_sizes, self.axis_names, + axis_types=new_axis_types) @property def devices(self): @@ -553,7 +550,7 @@ class SetAbstractMeshContextManager: set_abstract_mesh = SetAbstractMeshContextManager -empty_abstract_mesh = AbstractMesh(()) +empty_abstract_mesh = AbstractMesh((), ()) def get_abstract_mesh(): val = jax_config.abstract_mesh_context_manager.value diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3ccd8e3be..f43062326 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -502,7 +502,8 @@ def _as_manual_mesh(mesh, auto: frozenset): else: assert n in explicit_axes new_axis_types.append(AxisTypes.Explicit) - return AbstractMesh(mesh.shape_tuple, axis_types=tuple(new_axis_types)) + return AbstractMesh(mesh.axis_sizes, mesh.axis_names, + axis_types=tuple(new_axis_types)) def _extend_axis_env(mesh, auto): diff --git a/tests/array_test.py b/tests/array_test.py index dda2a0cf8..dbebe5169 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1338,14 +1338,14 @@ class ShardingTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, 'Number of axis names should match the number of axis_types'): - jax.sharding.AbstractMesh((('x', 2), ('y', 1)), + jax.sharding.AbstractMesh((2, 1), ('x', 'y'), axis_types=jax.sharding.AxisTypes.Auto) def test_make_mesh_axis_types(self): Auto, Explicit, Manual = AxisTypes.Auto, AxisTypes.Explicit, AxisTypes.Manual - mesh1 = jax.sharding.AbstractMesh((('x', 2),), axis_types=Auto) - mesh2 = jax.sharding.AbstractMesh((('x', 2),), axis_types=Auto) + mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto) + mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto) self.assertEqual(mesh1, mesh2) mesh = jax.make_mesh((1, 1), ('x', 'y')) @@ -1501,7 +1501,7 @@ class RngShardingTest(jtu.JaxTestCase): self.assertTrue(abstract_mesh.empty) self.assertEqual(abstract_mesh.size, 0) - abstract_mesh2 = jax.sharding.AbstractMesh(()) + abstract_mesh2 = jax.sharding.AbstractMesh((), ()) self.assertTrue(abstract_mesh2.empty) self.assertEqual(abstract_mesh2.size, 0) diff --git a/tests/export_test.py b/tests/export_test.py index 6baecebe1..2b083f312 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1173,7 +1173,7 @@ class JaxExportTest(jtu.JaxTestCase): if jax.local_device_count() < 2: self.skipTest("Need at least 2 devices") - abs_mesh = jax.sharding.AbstractMesh((("x", 2),)) + abs_mesh = jax.sharding.AbstractMesh((2,), 'x') input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None)) output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x")) @jax.jit diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9f8ce7209..8d9b84ce7 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4677,7 +4677,7 @@ class ArrayPjitTest(jtu.JaxTestCase): @jax.jit def f(x): x = with_sharding_constraint( - x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x'))) + x, NamedSharding(mesh1.abstract_mesh, P('x'))) return jax.lax.sin(x) with ( @@ -4701,7 +4701,7 @@ class ArrayPjitTest(jtu.JaxTestCase): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) - abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + abstract_mesh = mesh.abstract_mesh def f(x): x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x'))) @@ -4718,7 +4718,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_wsc_sds_abstract_mesh(self): mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P()) - abstract_mesh = mesh_lib.AbstractMesh(mesh.shape_tuple) + abstract_mesh = mesh.abstract_mesh @jax.jit def f(x): @@ -4747,7 +4747,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_wsc_abstract_mesh_errors(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) - abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + abstract_mesh = mesh.abstract_mesh s_abs = NamedSharding(abstract_mesh, P('x')) with self.assertRaisesRegex( @@ -4759,8 +4759,7 @@ class ArrayPjitTest(jtu.JaxTestCase): with_sharding_constraint(jnp.arange(8), s_abs) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) - abs_mesh2 = mesh_lib.AbstractMesh( - jtu.create_mesh((2,), 'y').shape_tuple) + abs_mesh2 = jtu.create_mesh((2,), 'y').abstract_mesh with self.assertRaisesRegex( ValueError, 'Mesh shape of the input.*does not' @@ -5647,7 +5646,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) full_user_mesh = mesh_lib.AbstractMesh( - (('x', 2), ('y', 2)), axis_types=(AxisTypes.Explicit,) * 2) + (2, 2), ('x', 'y'), axis_types=(AxisTypes.Explicit,) * 2) @jax.jit def f(x): @@ -6266,7 +6265,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_aval_spec_explicit_auto_complete(self): abstract_mesh = mesh_lib.AbstractMesh( - (('x', 2),), axis_types=(AxisTypes.Explicit,)) + (2,), 'x', axis_types=AxisTypes.Explicit) s = NamedSharding(abstract_mesh, P('x')) out = core.ShapedArray((8, 2), jnp.int32, sharding=s) self.assertEqual(out.sharding.spec, P('x', None)) @@ -7050,7 +7049,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_divisbility_aval_error(self): abstract_mesh = mesh_lib.AbstractMesh( - (('x', 2),), axis_types=(AxisTypes.Explicit,)) + (2,), ('x',), axis_types=AxisTypes.Explicit) s = NamedSharding(abstract_mesh, P('x')) with self.assertRaisesRegex( ValueError, 'does not evenly divide the dimension size'): @@ -7874,7 +7873,7 @@ class ShardyTest(jtu.JaxTestCase): self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])") def test_array_sharding_repr_with_logical_ids(self): - abstract_mesh = jax.sharding.AbstractMesh((('x', 4), ('y', 8), ('z', 2))) + abstract_mesh = jax.sharding.AbstractMesh((4, 8, 2), ('x', 'y', 'z')) ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None), _logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3]) self.assertEqual(repr(ns._to_sdy_sharding(4)), diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 2fd3a24d3..a8c003321 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -496,7 +496,7 @@ class RooflineTest(jtu.JaxTestCase): ]: out, result = roofline.roofline( f, - mesh=mesh.AbstractMesh(()), + mesh=mesh.AbstractMesh((), ()), in_specs=(P(), P()), out_specs=P(), )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) @@ -516,7 +516,7 @@ class RooflineTest(jtu.JaxTestCase): ]: _, result = roofline.roofline( lambda a, b: a + b, - mesh=mesh.AbstractMesh(()), + mesh=mesh.AbstractMesh((), ()), in_specs=(P(), P()), out_specs=P(), )(left, right) @@ -532,7 +532,7 @@ class RooflineTest(jtu.JaxTestCase): _, result = roofline.roofline( f, - mesh=mesh.AbstractMesh(()), + mesh=mesh.AbstractMesh((), ()), in_specs=(P(), P()), out_specs=P(), )(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int)) @@ -549,7 +549,7 @@ class RooflineTest(jtu.JaxTestCase): def test_no_specs(self): _, result = roofline.roofline( lambda a, b: a + b, - mesh=mesh.AbstractMesh(()), + mesh=mesh.AbstractMesh((), ()), )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) self.assertEqual(result.unfused_flops, 3 * 8) @@ -562,7 +562,7 @@ class RooflineTest(jtu.JaxTestCase): def test_dot_general(self): _, result = roofline.roofline( lambda a, b: a @ b, - mesh=mesh.AbstractMesh(()), + mesh=mesh.AbstractMesh((), ()), in_specs=(P(), P()), out_specs=P(), )(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int)) @@ -574,7 +574,7 @@ class RooflineTest(jtu.JaxTestCase): def test_reduce_sum_no_axis(self): _, result = roofline.roofline( lambda x: jnp.sum(x), - mesh=mesh.AbstractMesh(()), + mesh=mesh.AbstractMesh((), ()), in_specs=(P()), out_specs=P(), )(jnp.zeros((11, 4))) @@ -592,7 +592,7 @@ class RooflineTest(jtu.JaxTestCase): ]: _, result = roofline.roofline( lambda x: jnp.sum(x, axis=axis), - mesh=mesh.AbstractMesh(()), + mesh=mesh.AbstractMesh((), ()), in_specs=(P()), out_specs=P(), )(jnp.zeros((11, 4))) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 55012fe18..e9157759b 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -40,7 +40,7 @@ from jax._src import test_util as jtu from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals -from jax._src.mesh import AbstractMesh, AxisTypes +from jax._src.mesh import AxisTypes from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util @@ -796,7 +796,7 @@ class ShardMapTest(jtu.JaxTestCase): mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'i') mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'i') - abstract_mesh = AbstractMesh(mesh1.shape_tuple) + abstract_mesh = mesh1.abstract_mesh @jax.jit def f(x): @@ -823,7 +823,7 @@ class ShardMapTest(jtu.JaxTestCase): def test_shmap_abstract_mesh_errors(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) - abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + abstract_mesh = mesh.abstract_mesh with self.assertRaisesRegex( ValueError, @@ -834,7 +834,7 @@ class ShardMapTest(jtu.JaxTestCase): arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) mesh2 = jtu.create_mesh((2,), 'y') - abs_mesh2 = AbstractMesh(mesh2.shape_tuple) + abs_mesh2 = mesh2.abstract_mesh with self.assertRaisesRegex( ValueError, 'Mesh shape of the input.*does not match the mesh shape passed to'