From 49224d6cdb4dbae2a8dcc777c295ce6757d5a475 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 16 Jan 2025 17:55:15 -0800
Subject: [PATCH] Replace the Auto/User AxisTypes names with Hidden/Visible
 (Collective is unchanged).

Replace the `with set_mesh(mesh):` context manager with
`with use_mesh(mesh):`.

Also expose `AxisTypes` and `use_mesh` in the public API, as
`jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
---
 jax/_src/core.py                | 15 +++---
 jax/_src/interpreters/mlir.py   |  4 +-
 jax/_src/interpreters/pxla.py   |  2 +-
 jax/_src/lax/slicing.py         |  2 +-
 jax/_src/lax/utils.py           |  2 +-
 jax/_src/mesh.py                | 39 ++++++++++------
 jax/_src/numpy/array_methods.py |  6 +--
 jax/_src/pjit.py                | 20 ++++----
 jax/_src/sharding_impls.py      | 12 ++---
 jax/_src/test_util.py           |  4 +-
 jax/sharding.py                 |  6 ++-
 tests/array_test.py             | 16 ++++++-
 tests/pjit_test.py              | 82 ++++++++++++++++-----------------
 13 files changed, 120 insertions(+), 90 deletions(-)

diff --git a/jax/_src/core.py b/jax/_src/core.py
index 6b6ec9ef5..e83ce6546 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -729,9 +729,10 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
     # This attribute is part of the jax.Array API, but only defined on concrete arrays.
     # Raising a ConcretizationTypeError would make sense, but for backward compatibility
     # we raise an AttributeError so that hasattr() and getattr() work as expected.
-    raise AttributeError(self,
-        f"The 'sharding' attribute is not available on {self._error_repr()}."
-        f"{self._origin_msg()}")
+    raise AttributeError(
+        self,
+        f"The 'sharding' attribute is not available on {self._error_repr()}."
+        f"{self._origin_msg()}")
 
   @property
   def committed(self):
@@ -1674,7 +1675,7 @@ def _invalid_shape_error(shape: Shape, context: str=""):
 
 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
-def modify_spec_for_auto(spec, mesh) -> P:
+def modify_spec_for_hidden(spec, mesh) -> P:
   new_spec = []  # type: ignore
   for s in spec:
     if s is None:
@@ -1682,13 +1683,13 @@
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
       new_spec.append(
-          None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
+          None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Hidden else s)
   return P(*new_spec)
 
 def _maybe_modify_sharding(sharding):
-  if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
+  if mesh_lib.AxisTypes.Hidden not in sharding.mesh.axis_types:
     return sharding
-  new_spec = modify_spec_for_auto(sharding.spec, sharding.mesh)
+  new_spec = modify_spec_for_hidden(sharding.spec, sharding.mesh)
   return sharding.with_spec(new_spec)
 
 
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 92f3b72e4..fba5cc8bb 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -2596,7 +2596,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   # Don't emit a wsc under full manual mode to avoid increasing HLO size.
   if aval.sharding.mesh._are_all_axes_collective:
     return op
-  if aval.sharding.mesh._are_all_axes_auto:
+  if aval.sharding.mesh._are_all_axes_hidden:
     return op
   # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
   # `return op` early and avoid bloating HLO size.
@@ -2609,7 +2609,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
            if sharding_proto is None else sharding_proto)
   unspecified_dims = None
-  if aval.sharding.mesh._any_axis_auto:
+  if aval.sharding.mesh._any_axis_hidden:
     # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes
     # as unspecified?
     unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 3ae4131d7..0c7c01762 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2172,7 +2172,7 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
                               for sp in a.sharding.spec])
-              if a.sharding.mesh._any_axis_auto else a.sharding.spec)
+              if a.sharding.mesh._any_axis_hidden else a.sharding.spec)
       out.append(NamedSharding(
           _abstract_to_concrete_mesh(a.sharding.mesh), spec))
     else:
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index f058888ea..37cb099d7 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1883,7 +1883,7 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
                           slice_sizes, unique_indices, indices_are_sorted,
                           mode, fill_value):
   # TODO(yashkatariya): Write a proper gather sharding rule.
-  if mesh_lib.get_abstract_mesh()._are_all_axes_auto:  # type: ignore
+  if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
     return None
   raise GatherShardingError(
       "Use `.at[...].get(out_specs=)` to provide output PartitionSpec for the"
diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index cd0bb84b1..2b8757f78 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -49,7 +49,7 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
 
 def call_sharding_rule(rule, num_out, *avals, **kwargs):
   if config.sharding_in_types.value:
-    if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_auto:  # type: ignore
+    if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
       return None if num_out is None else [None] * num_out
     return rule(*avals, **kwargs)
   return None if num_out is None else [None] * num_out
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 02f64a689..e648b7bc6 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -103,8 +103,8 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
 
 
 class AxisTypes(enum.Enum):
-  Auto = enum.auto()
-  User = enum.auto()
+  Hidden = enum.auto()
+  Visible = enum.auto()
   Collective = enum.auto()
 
   def __repr__(self):
@@ -198,9 +198,15 @@ class Mesh(contextlib.ContextDecorator):
           f"devices.ndim == {devices.ndim} and "
           f"len(axis_names) == {len(axis_names)}.")
 
-    axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
+    axis_types = ({AxisTypes.Hidden: axis_names} if axis_types is None else
                   axis_types)
     axis_types_tuple = tuple(axis_types.items())
+    if len(axis_names_to_types(axis_types).keys()) != len(axis_names):
+      raise ValueError(
+          "Number of axis names in axis_types should match the number of"
+          f" axis_names. Got axis_names={axis_names} and"
+          f" axis_types={axis_types}")
+
     key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
     val = _mesh_object_dict.get(key, None)
     if val is not None:
@@ -356,16 +362,16 @@ class Mesh(contextlib.ContextDecorator):
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _are_all_axes_auto(self) -> bool:
-    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _are_all_axes_hidden(self) -> bool:
+    return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
   @functools.cached_property
   def _any_axis_collective(self) -> bool:
     return any(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _any_axis_auto(self) -> bool:
-    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _any_axis_hidden(self) -> bool:
+    return any(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
 
 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -397,9 +403,14 @@ class AbstractMesh:
       self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
     else:
       self._axis_names, self._axis_sizes = (), ()
-    self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
-                       else axis_types)
+    self.axis_types = ({AxisTypes.Hidden: self._axis_names}
+                       if axis_types is None else axis_types)
     self._axis_types_tuple = tuple(self.axis_types.items())
+    if len(self._name_to_type.keys()) != len(self._axis_names):
+      raise ValueError(
+          "Number of axis names in axis_types should match the number of"
+          f" axis_names in shape_tuple. Got axis_names={self._axis_names} and"
+          f" axis_types={self.axis_types}")
 
   def __hash__(self):
     return hash((self.shape_tuple, self._axis_types_tuple))
@@ -461,16 +472,16 @@ class AbstractMesh:
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _are_all_axes_auto(self) -> bool:
-    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _are_all_axes_hidden(self) -> bool:
+    return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
   @functools.cached_property
   def _any_axis_collective(self) -> bool:
     return any(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _any_axis_auto(self) -> bool:
-    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _any_axis_hidden(self) -> bool:
+    return any(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
   @property
   def devices(self):
@@ -535,7 +546,7 @@ def get_concrete_mesh():
 
 
 @contextlib.contextmanager
-def set_mesh(mesh: Mesh):
+def use_mesh(mesh: Mesh):
   with (set_abstract_mesh(mesh.abstract_mesh),
         jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
     yield
diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py
index 2e918db98..eda6e36fb 100644
--- a/jax/_src/numpy/array_methods.py
+++ b/jax/_src/numpy/array_methods.py
@@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.numpy import array_api_metadata
 from jax._src.numpy import lax_numpy
 from jax._src import mesh as mesh_lib
-from jax._src.pjit import auto_mode, PartitionSpec
+from jax._src.pjit import hidden_mode, PartitionSpec
 from jax._src.numpy import reductions
 from jax._src.numpy import ufuncs
 from jax._src.ops import scatter
@@ -781,8 +781,8 @@ class _IndexUpdateRef:
                       fill_value=fill_value)
     if out_spec is not None:
       assert isinstance(out_spec, PartitionSpec)
-      take = auto_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names,  # type: ignore
-                       out_specs=out_spec)
+      take = hidden_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names,  # type: ignore
+                         out_specs=out_spec)
     return take(self.array, self.index)
 
   def set(self, values, *, indices_are_sorted=False, unique_indices=False,
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index ccd087107..96b776557 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2754,11 +2754,11 @@ def _get_new_mesh(axes: str | tuple[str, ...], axis_type: mesh_lib.AxisTypes):
   new_mesh = cur_mesh.update_axis_types({axis_type: axes})  # type: ignore
   return new_mesh
 
-def auto_mode(fun, *, axes: str | tuple[str, ...], out_specs):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
+def hidden_mode(fun, *, axes: str | tuple[str, ...], out_specs):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
   def decorator(*args, **kwargs):
     with mesh_lib.set_abstract_mesh(new_mesh):
-      in_specs = tree_map(lambda a: core.modify_spec_for_auto(
+      in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
           a.sharding.spec, new_mesh), args)
       args = sharding_cast(args, in_specs)
       out = fun(*args, **kwargs)
@@ -2767,26 +2767,26 @@
 
 
 @contextlib.contextmanager
-def auto_mode_ctx(axes: str | tuple[str, ...]):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
+def hidden_axes(axes: str | tuple[str, ...]):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
   with mesh_lib.set_abstract_mesh(new_mesh):
     yield
 
 
-def user_mode(fun, *, axes: str | tuple[str, ...], in_specs):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.User)
+def visible_mode(fun, *, axes: str | tuple[str, ...], in_specs):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
   def decorator(*args, **kwargs):
     with mesh_lib.set_abstract_mesh(new_mesh):
       args = sharding_cast(args, in_specs)
       out = fun(*args, **kwargs)
-      out_specs = tree_map(lambda o: core.modify_spec_for_auto(
+      out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
           o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
       return sharding_cast(out, out_specs)
   return decorator
 
 
 @contextlib.contextmanager
-def user_mode_ctx(axes: str | tuple[str, ...]):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.User)
+def visible_axes(axes: str | tuple[str, ...]):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
   with mesh_lib.set_abstract_mesh(new_mesh):
     yield
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 53a19fd7f..8150c9a90 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -78,10 +78,10 @@ def _check_axis_type_consistency(mesh, parsed_pspec):
           'AxisTypes should be the same in a tuple subset of PartitionSpec:'
           f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
           f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
-  if mesh_lib.AxisTypes.Auto not in mesh.axis_types and None in parsed_pspec:
+  if mesh_lib.AxisTypes.Hidden not in mesh.axis_types and None in parsed_pspec:
     raise ValueError(
         f'PartitionSpec {parsed_pspec.get_partition_spec()} cannot contain'
-        ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
+        ' `P.UNCONSTRAINED` when no mesh axis_types are `Hidden`. Got mesh'
         f' axis_types: {mesh.axis_types}')
 
 
@@ -439,7 +439,7 @@ class NamedSharding(jsharding.Sharding):
 # TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
 # parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
 def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
-  if mesh._any_axis_auto:
+  if mesh._any_axis_hidden:
     dim_shardings, used_axes = [], []  # type: ignore
     for d in sdy_sharding.dimension_shardings:
       # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open?
@@ -448,7 +448,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
       used_axes.extend(d.axes)
     remaining_axes = set(mesh.axis_names) - set(used_axes)
     replicated_axes = tuple(r for r in remaining_axes
-                            if mesh._name_to_type[r] == mesh_lib.AxisTypes.User)
+                            if mesh._name_to_type[r] == mesh_lib.AxisTypes.Visible)
     return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
                             sdy_sharding.logical_device_ids, replicated_axes)
   return sdy_sharding
@@ -1774,9 +1774,9 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
 
   for s in flatten_spec(sharding.spec):
     if sharding.mesh._name_to_type[s] in {
-        mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
+        mesh_lib.AxisTypes.Hidden, mesh_lib.AxisTypes.Collective}:
       raise ValueError(
-          'PartitionSpec cannot contain axis names that are of type Auto or'
+          'PartitionSpec cannot contain axis names that are of type Hidden or'
           f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
          f' {s} or type: {sharding.mesh._name_to_type[s]}')
   return sharding
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 380f9b30a..69bb34669 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1549,12 +1549,12 @@ def with_and_without_mesh(f):
   ))(with_mesh_from_kwargs(f))
 
 def with_user_mesh(sizes, names, axis_types=None):
-  axis_types = ({mesh_lib.AxisTypes.User: names}
+  axis_types = ({mesh_lib.AxisTypes.Visible: names}
                 if axis_types is None else axis_types)
   def decorator(fn):
    def mesh_fn(*args, **kwargs):
      mesh = create_mesh(sizes, names, axis_types=axis_types)
-      with mesh_lib.set_mesh(mesh):
+      with mesh_lib.use_mesh(mesh):
        return fn(*args, **kwargs, mesh=mesh)
    return mesh_fn
   return decorator
diff --git a/jax/sharding.py b/jax/sharding.py
index 3c41439ef..c4e943f48 100644
--- a/jax/sharding.py
+++ b/jax/sharding.py
@@ -27,7 +27,11 @@ from jax._src.partition_spec import (
     PartitionSpec as PartitionSpec,
 )
 from jax._src.interpreters.pxla import Mesh as Mesh
-from jax._src.mesh import AbstractMesh as AbstractMesh
+from jax._src.mesh import (
+    AbstractMesh as AbstractMesh,
+    AxisTypes as AxisTypes,
+    use_mesh as use_mesh
+)
 
 _deprecations = {
     # Finalized 2024-10-01; remove after 2025-01-01.
diff --git a/tests/array_test.py b/tests/array_test.py
index bf7aa6488..8a4cb647b 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -1260,7 +1260,6 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, msg):
       jax.make_array_from_single_device_arrays(x.shape, s, [x, x])
 
-
   def test_gspmd_sharding_hash_eq(self):
     mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))
     ns = NamedSharding(mesh, P('x', 'y', 'z'))
@@ -1299,6 +1298,21 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, 'Mesh axis names cannot be None.'):
       jax.sharding.Mesh(jax.devices(), (None, 'x'))
 
+  def test_mesh_axis_types_mismatch(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'Number of axis names in axis_types should match the number of'
+        ' axis_names'):
+      jtu.create_mesh((2, 1), ('x', 'y'),
+                      axis_types={jax.sharding.AxisTypes.Hidden: 'x'})
+
+    with self.assertRaisesRegex(
+        ValueError,
+        'Number of axis names in axis_types should match the number of'
+        ' axis_names in shape_tuple'):
+      jax.sharding.AbstractMesh((('x', 2), ('y', 1)),
+                                axis_types={jax.sharding.AxisTypes.Hidden: 'x'})
+
 
 @jtu.with_config(jax_use_shardy_partitioner=True)
 class ShardyShardingTest(jtu.JaxTestCase):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 6abc83225..5ad79ed38 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -51,8 +51,8 @@ from jax._src import sharding_impls
 from jax._src.sharding_impls import (
     AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
     SingleDeviceSharding, parse_flatten_op_sharding)
-from jax._src.pjit import (pjit, sharding_cast, auto_mode, user_mode,
-                           auto_mode_ctx, user_mode_ctx)
+from jax._src.pjit import (pjit, sharding_cast, hidden_mode, visible_mode,
+                           hidden_axes, visible_axes)
 from jax._src import mesh as mesh_lib
 from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
 from jax._src.interpreters import pxla
@@ -5125,7 +5125,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
   @jtu.with_user_mesh((1,), 'x')
   def test_broadcasting_nary_error(self, mesh):
     mesh2 = Mesh([jax.devices()[0]], 'y',
-                 axis_types={mesh_lib.AxisTypes.User: 'y'})
+                 axis_types={mesh_lib.AxisTypes.Visible: 'y'})
 
     arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P()))
     arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@@ -5688,7 +5688,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
     self.check_wsc_in_lowered(f.lower(arr).as_text())
 
-  @jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
+  @jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_only_auto(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
@@ -5708,7 +5708,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
   def test_auto_user(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'),
-                           axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                           axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -5720,16 +5720,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       a = z @ x2
       return a
 
-    with mesh_lib.set_mesh(mesh):
+    with mesh_lib.use_mesh(mesh):
       out = f(arr, arr.T)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
     lowered_text = f.lower(arr, arr.T).as_text()
     self.assertNotIn('unspecified_dims', lowered_text)
 
     mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
-                            axis_types={mesh_lib.AxisTypes.User: 'x',
-                                        mesh_lib.AxisTypes.Auto: 'y'})
-    with mesh_lib.set_mesh(mesh2):
+                            axis_types={mesh_lib.AxisTypes.Visible: 'x',
+                                        mesh_lib.AxisTypes.Hidden: 'y'})
+    with mesh_lib.use_mesh(mesh2):
       arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
       arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
       out = f(arr, arr2)
@@ -5741,9 +5741,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
    self.assertTrue(lowered_text.count("unspecified_dims") == 5)
 
     mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
-                            axis_types={mesh_lib.AxisTypes.User: 'y',
-                                        mesh_lib.AxisTypes.Auto: 'x'})
-    with mesh_lib.set_mesh(mesh3):
+                            axis_types={mesh_lib.AxisTypes.Visible: 'y',
+                                        mesh_lib.AxisTypes.Hidden: 'x'})
+    with mesh_lib.use_mesh(mesh3):
       arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
       arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P('y', 'x')))
       out = f(arr, arr2)
@@ -5779,7 +5779,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x):
       y = x * 2
-      with auto_mode_ctx(axes=('x', 'y')):
+      with hidden_axes(axes=('x', 'y')):
         y = sharding_cast(y, P(None, None))
         self.assertEqual(y.sharding.spec, P(None, None))
         z = jnp.sin(y)
@@ -5798,7 +5798,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None)))
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                      axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_full_auto_to_full_user(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)
@@ -5807,7 +5807,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x):
       y = x * 2
-      with user_mode_ctx(axes=('x', 'y')):
+      with visible_axes(axes=('x', 'y')):
         y = sharding_cast(y, P(None, 'y'))
         self.assertEqual(y.sharding.spec, P(None, 'y'))
         z = jnp.sin(y)
@@ -5833,7 +5833,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x):
       y = x * 2
-      with auto_mode_ctx('x'):
+      with hidden_axes('x'):
         y = sharding_cast(y, P(None, 'y'))
         self.assertEqual(y.sharding.spec, P(None, 'y'))
         z = jnp.sin(y)
@@ -5860,7 +5860,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       x = x * 2
-      with auto_mode_ctx('x'):
+      with hidden_axes('x'):
         z = x @ y
       return z
 
@@ -5871,9 +5871,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
   def test_sharding_cast_src_dst_mesh_mismatch(self):
     np_inp = np.arange(16.).reshape(8, 2)
     mesh = jtu.create_mesh((2, 1), ('x', 'y'),
-                           axis_types={mesh_lib.AxisTypes.User: ('x', 'y')})
+                           axis_types={mesh_lib.AxisTypes.Visible: ('x', 'y')})
     mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
-                            axis_types={mesh_lib.AxisTypes.User: ('a', 'b')})
+                            axis_types={mesh_lib.AxisTypes.Visible: ('a', 'b')})
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
     f = lambda x: sharding_cast(x, NamedSharding(mesh2, P('a', 'b')))
@@ -5881,7 +5881,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         ValueError, "Mesh shape of the input.*does not match"):
       f(arr)
 
-    with mesh_lib.set_mesh(mesh):
+    with mesh_lib.use_mesh(mesh):
       with self.assertRaisesRegex(
           ValueError, "Mesh shape of the input.*does not match"):
         jax.jit(f)(arr)
@@ -5923,15 +5923,15 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     @jax.jit
     def f(x):
-      auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Auto: 'x'})
+      auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Hidden: 'x'})
       with set_abstract_mesh(auto_mesh):
         x = sharding_cast(x, P(None, None))
       return x
 
-    self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
+    self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
     out = f(arr)
     self.assertArraysEqual(out, np_inp)
-    self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
+    self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
 
   @jtu.with_user_mesh((2,), 'x')
   def test_inputs_different_context(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x'))
     arr = jax.device_put(np_inp, s)
 
-    auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Auto: 'x'})
-    with mesh_lib.set_mesh(auto_mesh):
+    auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Hidden: 'x'})
+    with mesh_lib.use_mesh(auto_mesh):
       arr2 = jnp.ones(8)
-      self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
+      self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
 
     @jax.jit
     def f(x, y):
      return x, y
 
     out1, out2 = f(arr, arr2)
-    self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
-    self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
+    self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
+    self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
 
   @jtu.with_user_mesh((2,), 'x')
   def test_output_different_context_error(self, mesh):
     np_inp1 = np.arange(16).reshape(8, 2)
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
     auto_mesh = jax.make_mesh((2,), 'x',
-                              axis_types={AxisTypes.Auto: 'x'}).abstract_mesh
+                              axis_types={AxisTypes.Hidden: 'x'}).abstract_mesh
 
     @jax.jit
     def f(x, y):
@@ -5972,17 +5972,17 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x, y):
-      with auto_mode_ctx('x'):
+      with hidden_axes('x'):
         out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
       return out
 
     with self.assertRaisesRegex(
-        ValueError, "PartitionSpec cannot contain axis names.*Auto"):
+        ValueError, "PartitionSpec cannot contain axis names.*Hidden"):
       g(arr1, arr2)
 
   @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'),
-                      axis_types={AxisTypes.User: ('x', 'y'),
-                                  AxisTypes.Auto: 'z'})
+                      axis_types={AxisTypes.Visible: ('x', 'y'),
+                                  AxisTypes.Hidden: 'z'})
   def test_out_sharding_mix_axis_types(self, mesh):
     np_inp = np.arange(16).reshape(4, 2, 2)
     s = NamedSharding(mesh, P('x', None, None))
     arr = jax.device_put(np_inp, s)
@@ -6011,7 +6011,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(auto_mode, axes='x', out_specs=P('x', None))
+    @partial(hidden_mode, axes='x', out_specs=P('x', None))
     def h(y):
       self.assertEqual(y.sharding.spec, P(None, 'y'))
       z = jnp.sin(y)
@@ -6035,13 +6035,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None)))
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                      axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_full_user_mode(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(user_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
+    @partial(visible_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
     def h(y):
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       z = jnp.sin(y)
@@ -6064,14 +6064,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     core.jaxpr_as_fun(jaxpr)(arr)  # doesn't crash
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.User: 'x',
-                                  mesh_lib.AxisTypes.Auto: 'y'})
+                      axis_types={mesh_lib.AxisTypes.Visible: 'x',
+                                  mesh_lib.AxisTypes.Hidden: 'y'})
   def test_mix_to_full_user_mode(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(user_mode, axes='y', in_specs=P('x', 'y'))
+    @partial(visible_mode, axes='y', in_specs=P('x', 'y'))
     def h(y):
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       z = jnp.sin(y)
@@ -6091,13 +6091,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                      axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_full_auto_to_partial_user(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(user_mode, axes='y', in_specs=P(None, 'y'))
+    @partial(visible_mode, axes='y', in_specs=P(None, 'y'))
     def h(y):
       self.assertEqual(y.sharding.spec, P(None, 'y'))
       z = jnp.sin(y)
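
Editor's note, not part of the patch: a minimal usage sketch of the renamed
public API, pieced together from the tests above. It assumes a runtime with at
least two devices; `jax.make_mesh`, `NamedSharding`, and `PartitionSpec`
already exist, while `AxisTypes` and `use_mesh` are the names this change
newly exports from `jax.sharding`.

    # Sketch only: Visible (formerly User) axes participate in
    # sharding-in-types; Hidden (formerly Auto) axes are left to the compiler.
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import AxisTypes, NamedSharding, PartitionSpec as P

    mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Visible: 'x'})

    # use_mesh (formerly set_mesh) installs the mesh and enables
    # sharding-in-types for the duration of the context.
    with jax.sharding.use_mesh(mesh):
      arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x')))
      out = jax.jit(jnp.sin)(arr)
      print(out.sharding.spec)  # expected: PartitionSpec('x',), as in the tests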
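
A second sketch, also not part of the patch, for the renamed private helpers
in jax._src.pjit (auto_mode_ctx/user_mode_ctx -> hidden_axes/visible_axes). It
mirrors the hidden_axes tests above; these are internal APIs and may change.

    # Sketch only: drop into Hidden (compiler-decided) mode for part of a
    # jitted function, then cast back to an explicit spec in Visible mode.
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax._src import mesh as mesh_lib
    from jax._src.pjit import hidden_axes, sharding_cast
    from jax.sharding import AxisTypes, NamedSharding, PartitionSpec as P

    mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Visible: 'x'})

    @jax.jit
    def f(x):
      y = x * 2
      with hidden_axes('x'):           # 'x' is Hidden inside this block, so
        y = sharding_cast(y, P(None))  # specs must not (and need not) use it
        z = jnp.sin(y)
      return sharding_cast(z, P('x'))  # back in Visible mode: explicit spec

    with mesh_lib.use_mesh(mesh):
      arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x')))
      out = f(arr)  # out.sharding.spec should again be P('x')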