mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types)
instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types)
so that we are consistent across all Mesh APIs: Mesh
, AbstractMesh
and make_mesh
PiperOrigin-RevId: 736371111
This commit is contained in:
parent
c6dcbb6759
commit
a4ca0dbc6c
@ -1440,7 +1440,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
|
||||
mesh = None
|
||||
if shardy_enabled:
|
||||
sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule))
|
||||
mesh = mesh_lib.AbstractMesh(tuple(sdy_mesh_axes)) if sdy_mesh_axes else None
|
||||
mesh = mesh_lib.AbstractMesh(
|
||||
*list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None
|
||||
|
||||
axis_context = ctx.module_context.axis_context
|
||||
if isinstance(axis_context, sharding_impls.ShardingContext):
|
||||
|
@ -416,7 +416,8 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator):
|
||||
|
||||
@functools.cached_property
|
||||
def abstract_mesh(self):
|
||||
return AbstractMesh(self.shape_tuple, axis_types=self._axis_types)
|
||||
return AbstractMesh(self.axis_sizes, self.axis_names,
|
||||
axis_types=self._axis_types)
|
||||
|
||||
|
||||
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
|
||||
@ -441,15 +442,12 @@ class AbstractMesh(_BaseMesh):
|
||||
details.
|
||||
"""
|
||||
|
||||
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
|
||||
axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None):
|
||||
self.shape_tuple = shape_tuple
|
||||
if self.shape_tuple:
|
||||
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
|
||||
else:
|
||||
self._axis_names, self._axis_sizes = (), ()
|
||||
self._size = math.prod(self._axis_sizes) if self._axis_sizes else 0
|
||||
self._axis_types = _normalize_axis_types(self._axis_names, axis_types)
|
||||
def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...],
|
||||
*, axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None):
|
||||
self.axis_sizes = axis_sizes
|
||||
self.axis_names = axis_names
|
||||
self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0
|
||||
self._axis_types = _normalize_axis_types(self.axis_names, axis_types)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.shape_tuple, self._axis_types))
|
||||
@ -468,14 +466,6 @@ class AbstractMesh(_BaseMesh):
|
||||
atr = f", axis_types={self._axis_types}"
|
||||
return f"AbstractMesh({mesh_repr}{atr})"
|
||||
|
||||
@property
|
||||
def axis_names(self):
|
||||
return self._axis_names
|
||||
|
||||
@property
|
||||
def axis_sizes(self) -> tuple[int, ...]:
|
||||
return self._axis_sizes
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return self._size
|
||||
@ -484,6 +474,12 @@ class AbstractMesh(_BaseMesh):
|
||||
def shape(self):
|
||||
return collections.OrderedDict(self.shape_tuple)
|
||||
|
||||
@functools.cached_property
|
||||
def shape_tuple(self):
|
||||
return tuple(
|
||||
(name, size)
|
||||
for name, size in safe_zip(self.axis_names, self.axis_sizes))
|
||||
|
||||
@property
|
||||
def _internal_device_list(self):
|
||||
return None
|
||||
@ -499,7 +495,8 @@ class AbstractMesh(_BaseMesh):
|
||||
def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisTypes]):
|
||||
new_axis_types = tuple(name_to_type[n] if n in name_to_type else a
|
||||
for n, a in zip(self.axis_names, self._axis_types))
|
||||
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
|
||||
return AbstractMesh(self.axis_sizes, self.axis_names,
|
||||
axis_types=new_axis_types)
|
||||
|
||||
@property
|
||||
def devices(self):
|
||||
@ -553,7 +550,7 @@ class SetAbstractMeshContextManager:
|
||||
set_abstract_mesh = SetAbstractMeshContextManager
|
||||
|
||||
|
||||
empty_abstract_mesh = AbstractMesh(())
|
||||
empty_abstract_mesh = AbstractMesh((), ())
|
||||
|
||||
def get_abstract_mesh():
|
||||
val = jax_config.abstract_mesh_context_manager.value
|
||||
|
@ -502,7 +502,8 @@ def _as_manual_mesh(mesh, auto: frozenset):
|
||||
else:
|
||||
assert n in explicit_axes
|
||||
new_axis_types.append(AxisTypes.Explicit)
|
||||
return AbstractMesh(mesh.shape_tuple, axis_types=tuple(new_axis_types))
|
||||
return AbstractMesh(mesh.axis_sizes, mesh.axis_names,
|
||||
axis_types=tuple(new_axis_types))
|
||||
|
||||
|
||||
def _extend_axis_env(mesh, auto):
|
||||
|
@ -1338,14 +1338,14 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Number of axis names should match the number of axis_types'):
|
||||
jax.sharding.AbstractMesh((('x', 2), ('y', 1)),
|
||||
jax.sharding.AbstractMesh((2, 1), ('x', 'y'),
|
||||
axis_types=jax.sharding.AxisTypes.Auto)
|
||||
|
||||
def test_make_mesh_axis_types(self):
|
||||
Auto, Explicit, Manual = AxisTypes.Auto, AxisTypes.Explicit, AxisTypes.Manual
|
||||
|
||||
mesh1 = jax.sharding.AbstractMesh((('x', 2),), axis_types=Auto)
|
||||
mesh2 = jax.sharding.AbstractMesh((('x', 2),), axis_types=Auto)
|
||||
mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
|
||||
mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
|
||||
self.assertEqual(mesh1, mesh2)
|
||||
|
||||
mesh = jax.make_mesh((1, 1), ('x', 'y'))
|
||||
@ -1501,7 +1501,7 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
self.assertTrue(abstract_mesh.empty)
|
||||
self.assertEqual(abstract_mesh.size, 0)
|
||||
|
||||
abstract_mesh2 = jax.sharding.AbstractMesh(())
|
||||
abstract_mesh2 = jax.sharding.AbstractMesh((), ())
|
||||
self.assertTrue(abstract_mesh2.empty)
|
||||
self.assertEqual(abstract_mesh2.size, 0)
|
||||
|
||||
|
@ -1173,7 +1173,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
if jax.local_device_count() < 2:
|
||||
self.skipTest("Need at least 2 devices")
|
||||
|
||||
abs_mesh = jax.sharding.AbstractMesh((("x", 2),))
|
||||
abs_mesh = jax.sharding.AbstractMesh((2,), 'x')
|
||||
input_sharding = jax.sharding.NamedSharding(abs_mesh, P("x", None))
|
||||
output_sharding = jax.sharding.NamedSharding(abs_mesh, P(None, "x"))
|
||||
@jax.jit
|
||||
|
@ -4677,7 +4677,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
x = with_sharding_constraint(
|
||||
x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x')))
|
||||
x, NamedSharding(mesh1.abstract_mesh, P('x')))
|
||||
return jax.lax.sin(x)
|
||||
|
||||
with (
|
||||
@ -4701,7 +4701,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
|
||||
abstract_mesh = mesh.abstract_mesh
|
||||
|
||||
def f(x):
|
||||
x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
|
||||
@ -4718,7 +4718,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
def test_wsc_sds_abstract_mesh(self):
|
||||
mesh = jtu.create_mesh((2,), 'x')
|
||||
s = NamedSharding(mesh, P())
|
||||
abstract_mesh = mesh_lib.AbstractMesh(mesh.shape_tuple)
|
||||
abstract_mesh = mesh.abstract_mesh
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -4747,7 +4747,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
def test_wsc_abstract_mesh_errors(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
np_inp = np.arange(8)
|
||||
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
|
||||
abstract_mesh = mesh.abstract_mesh
|
||||
s_abs = NamedSharding(abstract_mesh, P('x'))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
@ -4759,8 +4759,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with_sharding_constraint(jnp.arange(8), s_abs)
|
||||
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
||||
abs_mesh2 = mesh_lib.AbstractMesh(
|
||||
jtu.create_mesh((2,), 'y').shape_tuple)
|
||||
abs_mesh2 = jtu.create_mesh((2,), 'y').abstract_mesh
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Mesh shape of the input.*does not'
|
||||
@ -5647,7 +5646,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
full_user_mesh = mesh_lib.AbstractMesh(
|
||||
(('x', 2), ('y', 2)), axis_types=(AxisTypes.Explicit,) * 2)
|
||||
(2, 2), ('x', 'y'), axis_types=(AxisTypes.Explicit,) * 2)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -6266,7 +6265,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_aval_spec_explicit_auto_complete(self):
|
||||
abstract_mesh = mesh_lib.AbstractMesh(
|
||||
(('x', 2),), axis_types=(AxisTypes.Explicit,))
|
||||
(2,), 'x', axis_types=AxisTypes.Explicit)
|
||||
s = NamedSharding(abstract_mesh, P('x'))
|
||||
out = core.ShapedArray((8, 2), jnp.int32, sharding=s)
|
||||
self.assertEqual(out.sharding.spec, P('x', None))
|
||||
@ -7050,7 +7049,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_divisbility_aval_error(self):
|
||||
abstract_mesh = mesh_lib.AbstractMesh(
|
||||
(('x', 2),), axis_types=(AxisTypes.Explicit,))
|
||||
(2,), ('x',), axis_types=AxisTypes.Explicit)
|
||||
s = NamedSharding(abstract_mesh, P('x'))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'does not evenly divide the dimension size'):
|
||||
@ -7874,7 +7873,7 @@ class ShardyTest(jtu.JaxTestCase):
|
||||
self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])")
|
||||
|
||||
def test_array_sharding_repr_with_logical_ids(self):
|
||||
abstract_mesh = jax.sharding.AbstractMesh((('x', 4), ('y', 8), ('z', 2)))
|
||||
abstract_mesh = jax.sharding.AbstractMesh((4, 8, 2), ('x', 'y', 'z'))
|
||||
ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None),
|
||||
_logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3])
|
||||
self.assertEqual(repr(ns._to_sdy_sharding(4)),
|
||||
|
@ -496,7 +496,7 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
]:
|
||||
out, result = roofline.roofline(
|
||||
f,
|
||||
mesh=mesh.AbstractMesh(()),
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P(), P()),
|
||||
out_specs=P(),
|
||||
)(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int))
|
||||
@ -516,7 +516,7 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
]:
|
||||
_, result = roofline.roofline(
|
||||
lambda a, b: a + b,
|
||||
mesh=mesh.AbstractMesh(()),
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P(), P()),
|
||||
out_specs=P(),
|
||||
)(left, right)
|
||||
@ -532,7 +532,7 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
|
||||
_, result = roofline.roofline(
|
||||
f,
|
||||
mesh=mesh.AbstractMesh(()),
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P(), P()),
|
||||
out_specs=P(),
|
||||
)(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int))
|
||||
@ -549,7 +549,7 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
def test_no_specs(self):
|
||||
_, result = roofline.roofline(
|
||||
lambda a, b: a + b,
|
||||
mesh=mesh.AbstractMesh(()),
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
)(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int))
|
||||
self.assertEqual(result.unfused_flops, 3 * 8)
|
||||
|
||||
@ -562,7 +562,7 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
def test_dot_general(self):
|
||||
_, result = roofline.roofline(
|
||||
lambda a, b: a @ b,
|
||||
mesh=mesh.AbstractMesh(()),
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P(), P()),
|
||||
out_specs=P(),
|
||||
)(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int))
|
||||
@ -574,7 +574,7 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
def test_reduce_sum_no_axis(self):
|
||||
_, result = roofline.roofline(
|
||||
lambda x: jnp.sum(x),
|
||||
mesh=mesh.AbstractMesh(()),
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P()),
|
||||
out_specs=P(),
|
||||
)(jnp.zeros((11, 4)))
|
||||
@ -592,7 +592,7 @@ class RooflineTest(jtu.JaxTestCase):
|
||||
]:
|
||||
_, result = roofline.roofline(
|
||||
lambda x: jnp.sum(x, axis=axis),
|
||||
mesh=mesh.AbstractMesh(()),
|
||||
mesh=mesh.AbstractMesh((), ()),
|
||||
in_specs=(P()),
|
||||
out_specs=P(),
|
||||
)(jnp.zeros((11, 4)))
|
||||
|
@ -40,7 +40,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax._src.mesh import AbstractMesh, AxisTypes
|
||||
from jax._src.mesh import AxisTypes
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
@ -796,7 +796,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'i')
|
||||
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'i')
|
||||
abstract_mesh = AbstractMesh(mesh1.shape_tuple)
|
||||
abstract_mesh = mesh1.abstract_mesh
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -823,7 +823,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def test_shmap_abstract_mesh_errors(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
np_inp = np.arange(8)
|
||||
abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple)
|
||||
abstract_mesh = mesh.abstract_mesh
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -834,7 +834,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
||||
mesh2 = jtu.create_mesh((2,), 'y')
|
||||
abs_mesh2 = AbstractMesh(mesh2.shape_tuple)
|
||||
abs_mesh2 = mesh2.abstract_mesh
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Mesh shape of the input.*does not match the mesh shape passed to'
|
||||
|
Loading…
x
Reference in New Issue
Block a user