Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh

PiperOrigin-RevId: 736371111
This commit is contained in:
Yash Katariya 2025-03-12 21:31:51 -07:00 committed by jax authors
parent c6dcbb6759
commit a4ca0dbc6c
8 changed files with 46 additions and 48 deletions

View File

@ -1440,7 +1440,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
mesh = None
if shardy_enabled:
sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule))
mesh = mesh_lib.AbstractMesh(tuple(sdy_mesh_axes)) if sdy_mesh_axes else None
mesh = mesh_lib.AbstractMesh(
*list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):

View File

@ -416,7 +416,8 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator):
@functools.cached_property
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple, axis_types=self._axis_types)
return AbstractMesh(self.axis_sizes, self.axis_names,
axis_types=self._axis_types)
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@ -441,15 +442,12 @@ class AbstractMesh(_BaseMesh):
details.
"""
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None):
self.shape_tuple = shape_tuple
if self.shape_tuple:
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
else:
self._axis_names, self._axis_sizes = (), ()
self._size = math.prod(self._axis_sizes) if self._axis_sizes else 0
self._axis_types = _normalize_axis_types(self._axis_names, axis_types)
def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...],
*, axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None):
self.axis_sizes = axis_sizes
self.axis_names = axis_names
self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0
self._axis_types = _normalize_axis_types(self.axis_names, axis_types)
def __hash__(self):
return hash((self.shape_tuple, self._axis_types))
@ -468,14 +466,6 @@ class AbstractMesh(_BaseMesh):
atr = f", axis_types={self._axis_types}"
return f"AbstractMesh({mesh_repr}{atr})"
@property
def axis_names(self):
return self._axis_names
@property
def axis_sizes(self) -> tuple[int, ...]:
return self._axis_sizes
@property
def size(self):
return self._size
@ -484,6 +474,12 @@ class AbstractMesh(_BaseMesh):
def shape(self):
return collections.OrderedDict(self.shape_tuple)
@functools.cached_property
def shape_tuple(self):
return tuple(
(name, size)
for name, size in safe_zip(self.axis_names, self.axis_sizes))
@property
def _internal_device_list(self):
return None
@ -499,7 +495,8 @@ class AbstractMesh(_BaseMesh):
def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisTypes]):
new_axis_types = tuple(name_to_type[n] if n in name_to_type else a
for n, a in zip(self.axis_names, self._axis_types))
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
return AbstractMesh(self.axis_sizes, self.axis_names,
axis_types=new_axis_types)
@property
def devices(self):
@ -553,7 +550,7 @@ class SetAbstractMeshContextManager:
set_abstract_mesh = SetAbstractMeshContextManager
empty_abstract_mesh = AbstractMesh(())
empty_abstract_mesh = AbstractMesh((), ())
def get_abstract_mesh():
val = jax_config.abstract_mesh_context_manager.value

View File

@ -502,7 +502,8 @@ def _as_manual_mesh(mesh, auto: frozenset):
else:
assert n in explicit_axes
new_axis_types.append(AxisTypes.Explicit)
return AbstractMesh(mesh.shape_tuple, axis_types=tuple(new_axis_types))
return AbstractMesh(mesh.axis_sizes, mesh.axis_names,
axis_types=tuple(new_axis_types))
def _extend_axis_env(mesh, auto):

View File

@ -1338,14 +1338,14 @@ class ShardingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
'Number of axis names should match the number of axis_types'):
jax.sharding.AbstractMesh((('x', 2), ('y', 1)),
jax.sharding.AbstractMesh((2, 1), ('x', 'y'),
axis_types=jax.sharding.AxisTypes.Auto)
def test_make_mesh_axis_types(self):
Auto, Explicit, Manual = AxisTypes.Auto, AxisTypes.Explicit, AxisTypes.Manual
mesh1 = jax.sharding.AbstractMesh((('x', 2),), axis_types=Auto)
mesh2 = jax.sharding.AbstractMesh((('x', 2),), axis_types=Auto)
mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
self.assertEqual(mesh1, mesh2)
mesh = jax.make_mesh((1, 1), ('x', 'y'))
@ -1501,7 +1501,7 @@ class RngShardingTest(jtu.JaxTestCase):
self.assertTrue(abstract_mesh.empty)
self.assertEqual(abstract_mesh.size, 0)
abstract_mesh2 = jax.sharding.AbstractMesh(())
abstract_mesh2 = jax.sharding.AbstractMesh((), ())
self.assertTrue(abstract_mesh2.empty)
self.assertEqual(abstract_mesh2.size, 0)

View File

@ -1173,7 +1173,7 @@ class JaxExportTest(jtu.JaxTestCase):
if jax.local_device_count() < 2:
self.skipTest("Need at least 2 devices")
abs_mesh = jax.sharding.AbstractMesh((("x", 2),))
abs_mesh = jax.sharding.AbstractMesh((2,), 'x')
input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None))
output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x"))
@jax.jit

View File

@ -4677,7 +4677,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax.jit
def f(x):
x = with_sharding_constraint(
x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x')))
x, NamedSharding(mesh1.abstract_mesh, P('x')))
return jax.lax.sin(x)
with (
@ -4701,7 +4701,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
abstract_mesh = mesh.abstract_mesh
def f(x):
x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
@ -4718,7 +4718,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_wsc_sds_abstract_mesh(self):
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P())
abstract_mesh = mesh_lib.AbstractMesh(mesh.shape_tuple)
abstract_mesh = mesh.abstract_mesh
@jax.jit
def f(x):
@ -4747,7 +4747,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_wsc_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
abstract_mesh = mesh.abstract_mesh
s_abs = NamedSharding(abstract_mesh, P('x'))
with self.assertRaisesRegex(
@ -4759,8 +4759,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with_sharding_constraint(jnp.arange(8), s_abs)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
abs_mesh2 = mesh_lib.AbstractMesh(
jtu.create_mesh((2,), 'y').shape_tuple)
abs_mesh2 = jtu.create_mesh((2,), 'y').abstract_mesh
with self.assertRaisesRegex(
ValueError,
'Mesh shape of the input.*does not'
@ -5647,7 +5646,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
full_user_mesh = mesh_lib.AbstractMesh(
(('x', 2), ('y', 2)), axis_types=(AxisTypes.Explicit,) * 2)
(2, 2), ('x', 'y'), axis_types=(AxisTypes.Explicit,) * 2)
@jax.jit
def f(x):
@ -6266,7 +6265,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_aval_spec_explicit_auto_complete(self):
abstract_mesh = mesh_lib.AbstractMesh(
(('x', 2),), axis_types=(AxisTypes.Explicit,))
(2,), 'x', axis_types=AxisTypes.Explicit)
s = NamedSharding(abstract_mesh, P('x'))
out = core.ShapedArray((8, 2), jnp.int32, sharding=s)
self.assertEqual(out.sharding.spec, P('x', None))
@ -7050,7 +7049,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_divisbility_aval_error(self):
abstract_mesh = mesh_lib.AbstractMesh(
(('x', 2),), axis_types=(AxisTypes.Explicit,))
(2,), ('x',), axis_types=AxisTypes.Explicit)
s = NamedSharding(abstract_mesh, P('x'))
with self.assertRaisesRegex(
ValueError, 'does not evenly divide the dimension size'):
@ -7874,7 +7873,7 @@ class ShardyTest(jtu.JaxTestCase):
self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])")
def test_array_sharding_repr_with_logical_ids(self):
abstract_mesh = jax.sharding.AbstractMesh((('x', 4), ('y', 8), ('z', 2)))
abstract_mesh = jax.sharding.AbstractMesh((4, 8, 2), ('x', 'y', 'z'))
ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None),
_logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
self.assertEqual(repr(ns._to_sdy_sharding(4)),

View File

@ -496,7 +496,7 @@ class RooflineTest(jtu.JaxTestCase):
]:
out, result = roofline.roofline(
f,
mesh=mesh.AbstractMesh(()),
mesh=mesh.AbstractMesh((), ()),
in_specs=(P(), P()),
out_specs=P(),
)(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int))
@ -516,7 +516,7 @@ class RooflineTest(jtu.JaxTestCase):
]:
_, result = roofline.roofline(
lambda a, b: a + b,
mesh=mesh.AbstractMesh(()),
mesh=mesh.AbstractMesh((), ()),
in_specs=(P(), P()),
out_specs=P(),
)(left, right)
@ -532,7 +532,7 @@ class RooflineTest(jtu.JaxTestCase):
_, result = roofline.roofline(
f,
mesh=mesh.AbstractMesh(()),
mesh=mesh.AbstractMesh((), ()),
in_specs=(P(), P()),
out_specs=P(),
)(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int))
@ -549,7 +549,7 @@ class RooflineTest(jtu.JaxTestCase):
def test_no_specs(self):
_, result = roofline.roofline(
lambda a, b: a + b,
mesh=mesh.AbstractMesh(()),
mesh=mesh.AbstractMesh((), ()),
)(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int))
self.assertEqual(result.unfused_flops, 3 * 8)
@ -562,7 +562,7 @@ class RooflineTest(jtu.JaxTestCase):
def test_dot_general(self):
_, result = roofline.roofline(
lambda a, b: a @ b,
mesh=mesh.AbstractMesh(()),
mesh=mesh.AbstractMesh((), ()),
in_specs=(P(), P()),
out_specs=P(),
)(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int))
@ -574,7 +574,7 @@ class RooflineTest(jtu.JaxTestCase):
def test_reduce_sum_no_axis(self):
_, result = roofline.roofline(
lambda x: jnp.sum(x),
mesh=mesh.AbstractMesh(()),
mesh=mesh.AbstractMesh((), ()),
in_specs=(P()),
out_specs=P(),
)(jnp.zeros((11, 4)))
@ -592,7 +592,7 @@ class RooflineTest(jtu.JaxTestCase):
]:
_, result = roofline.roofline(
lambda x: jnp.sum(x, axis=axis),
mesh=mesh.AbstractMesh(()),
mesh=mesh.AbstractMesh((), ()),
in_specs=(P()),
out_specs=P(),
)(jnp.zeros((11, 4)))

View File

@ -40,7 +40,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.ad_checkpoint import saved_residuals
from jax._src.mesh import AbstractMesh, AxisTypes
from jax._src.mesh import AxisTypes
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
@ -796,7 +796,7 @@ class ShardMapTest(jtu.JaxTestCase):
mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'i')
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'i')
abstract_mesh = AbstractMesh(mesh1.shape_tuple)
abstract_mesh = mesh1.abstract_mesh
@jax.jit
def f(x):
@ -823,7 +823,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_shmap_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
abstract_mesh = mesh.abstract_mesh
with self.assertRaisesRegex(
ValueError,
@ -834,7 +834,7 @@ class ShardMapTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
mesh2 = jtu.create_mesh((2,), 'y')
abs_mesh2 = AbstractMesh(mesh2.shape_tuple)
abs_mesh2 = mesh2.abstract_mesh
with self.assertRaisesRegex(
ValueError,
'Mesh shape of the input.*does not match the mesh shape passed to'