From 88d4bc3d45fab05ff7a6909f2db2d1e9963fea46 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 14 Mar 2025 11:47:33 -0700 Subject: [PATCH] Rename AxisTypes enum to AxisType PiperOrigin-RevId: 736935746 --- jax/_src/core.py | 4 +- jax/_src/mesh.py | 34 +++++++------- jax/_src/named_sharding.py | 4 +- jax/_src/pjit.py | 26 +++++------ jax/_src/sharding_impls.py | 6 +-- jax/_src/test_util.py | 2 +- jax/experimental/shard_map.py | 14 +++--- jax/sharding.py | 2 +- tests/array_test.py | 22 ++++----- tests/mutable_array_test.py | 4 +- tests/pjit_test.py | 86 +++++++++++++++++------------------ tests/shard_map_test.py | 14 +++--- 12 files changed, 109 insertions(+), 109 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 167244064..36ce2f004 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -39,7 +39,7 @@ from jax._src import config from jax._src import effects from jax._src import compute_on from jax._src import mesh as mesh_lib -from jax._src.mesh import AxisTypes +from jax._src.mesh import AxisType from jax._src.partition_spec import PartitionSpec as P from jax._src.errors import ( ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, @@ -1835,7 +1835,7 @@ def modify_spec_for_auto_manual(spec, mesh) -> P: temp_s = s[0] if isinstance(s, tuple) else s new_spec.append( None - if mesh._name_to_type[temp_s] in (AxisTypes.Auto, AxisTypes.Manual) + if mesh._name_to_type[temp_s] in (AxisType.Auto, AxisType.Manual) else s) return P(*new_spec) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 1e313d058..94e27a2ba 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -103,7 +103,7 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names) -class AxisTypes(enum.Enum): +class AxisType(enum.Enum): Auto = enum.auto() Explicit = enum.auto() Manual = enum.auto() @@ -112,10 +112,10 @@ class AxisTypes(enum.Enum): return self.name def _normalize_axis_types(axis_names, axis_types): - axis_types = ((AxisTypes.Auto,) * len(axis_names) + axis_types = ((AxisType.Auto,) * len(axis_names) if axis_types is None else axis_types) if not isinstance(axis_types, tuple): - assert isinstance(axis_types, AxisTypes), axis_types + assert isinstance(axis_types, AxisType), axis_types axis_types = (axis_types,) if len(axis_names) != len(axis_types): raise ValueError( @@ -123,12 +123,12 @@ def _normalize_axis_types(axis_names, axis_types): f" axis_names={axis_names} and axis_types={axis_types}") return axis_types -def all_axis_types_match(axis_types, ty: AxisTypes) -> bool: +def all_axis_types_match(axis_types, ty: AxisType) -> bool: if not axis_types: return False return all(t == ty for t in axis_types) -def any_axis_types_match(axis_types, ty: AxisTypes) -> bool: +def any_axis_types_match(axis_types, ty: AxisType) -> bool: if not axis_types: return False return any(t == ty for t in axis_types) @@ -137,42 +137,42 @@ def any_axis_types_match(axis_types, ty: AxisTypes) -> bool: class _BaseMesh: axis_names: tuple[MeshAxisName, ...] shape_tuple: tuple[tuple[str, int], ...] - _axis_types: tuple[AxisTypes, ...] + _axis_types: tuple[AxisType, ...] @property - def axis_types(self) -> tuple[AxisTypes, ...]: + def axis_types(self) -> tuple[AxisType, ...]: return self._axis_types @functools.cached_property def _are_all_axes_manual(self) -> bool: - return all_axis_types_match(self._axis_types, AxisTypes.Manual) + return all_axis_types_match(self._axis_types, AxisType.Manual) @functools.cached_property def _are_all_axes_auto(self) -> bool: - return all_axis_types_match(self._axis_types, AxisTypes.Auto) + return all_axis_types_match(self._axis_types, AxisType.Auto) @functools.cached_property def _are_all_axes_explicit(self) -> bool: - return all_axis_types_match(self._axis_types, AxisTypes.Explicit) + return all_axis_types_match(self._axis_types, AxisType.Explicit) @functools.cached_property def _are_all_axes_auto_or_manual(self) -> bool: if not self._axis_types: return False - return all(t == AxisTypes.Auto or t == AxisTypes.Manual + return all(t == AxisType.Auto or t == AxisType.Manual for t in self._axis_types) @functools.cached_property def _any_axis_manual(self) -> bool: - return any_axis_types_match(self._axis_types, AxisTypes.Manual) + return any_axis_types_match(self._axis_types, AxisType.Manual) @functools.cached_property def _any_axis_auto(self) -> bool: - return any_axis_types_match(self._axis_types, AxisTypes.Auto) + return any_axis_types_match(self._axis_types, AxisType.Auto) @functools.cached_property def _any_axis_explicit(self) -> bool: - return any_axis_types_match(self._axis_types, AxisTypes.Explicit) + return any_axis_types_match(self._axis_types, AxisType.Explicit) @functools.cached_property def _axis_types_dict(self): @@ -247,7 +247,7 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): def __new__(cls, devices: np.ndarray | Sequence[xc.Device], axis_names: str | Sequence[MeshAxisName], *, - axis_types: tuple[AxisTypes, ...] | None = None): + axis_types: tuple[AxisType, ...] | None = None): if not isinstance(devices, np.ndarray): devices = np.array(devices) if isinstance(axis_names, str): @@ -443,7 +443,7 @@ class AbstractMesh(_BaseMesh): """ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], - *, axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None): + *, axis_types: AxisType | tuple[AxisType, ...] | None = None): self.axis_sizes = axis_sizes self.axis_names = axis_names self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0 @@ -494,7 +494,7 @@ class AbstractMesh(_BaseMesh): def abstract_mesh(self): return self - def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisTypes]): + def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisType]): new_axis_types = tuple(name_to_type[n] if n in name_to_type else a for n, a in zip(self.axis_names, self._axis_types)) return AbstractMesh(self.axis_sizes, self.axis_names, diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index f05e83f08..5accdd880 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -409,7 +409,7 @@ def named_sharding_to_xla_hlo_sharding( special_axes = {} mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() - if t == mesh_lib.AxisTypes.Manual} + if t == mesh_lib.AxisType.Manual} manual_axes = self._manual_axes.union(mesh_manual_axes) if manual_axes: axis_names = self.mesh.axis_names @@ -564,7 +564,7 @@ def _check_mesh_resource_axis(mesh, pspec, _manual_axes): 'AxisTypes should be the same in a tuple subset of PartitionSpec:' f' {pspec}. Got subset {p} with axis' f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})') - if (mesh_lib.AxisTypes.Auto not in mesh._axis_types_dict and + if (mesh_lib.AxisType.Auto not in mesh._axis_types_dict and PartitionSpec.UNCONSTRAINED in pspec): raise ValueError( f'{pspec} cannot contain' diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b0308e436..f7a4361ff 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2488,7 +2488,7 @@ def check_shardings_are_auto(shardings_flat): if not isinstance(s, NamedSharding): continue mesh = s.mesh.abstract_mesh - if not all(mesh._name_to_type[i] == mesh_lib.AxisTypes.Auto + if not all(mesh._name_to_type[i] == mesh_lib.AxisType.Auto for axes in s.spec if axes is not PartitionSpec.UNCONSTRAINED and axes is not None for i in (axes if isinstance(axes, tuple) else (axes,))): @@ -2712,7 +2712,7 @@ def _mesh_cast_abstract_eval(aval, dst_sharding): if (src_sharding.mesh._axis_types_dict == dst_sharding.mesh._axis_types_dict and src_sharding.spec != dst_sharding.spec): raise ValueError( - 'mesh_cast should only be used when AxisTypes changes between the' + 'mesh_cast should only be used when AxisType changes between the' ' input mesh and the target mesh. Got src' f' axis_types={src_sharding.mesh._axis_types_dict} and dst' f' axis_types={dst_sharding.mesh._axis_types_dict}. To reshard between' @@ -2723,12 +2723,12 @@ def _mesh_cast_abstract_eval(aval, dst_sharding): if s is None and d is None: continue if s is None and d is not None: - assert (src_sharding.mesh._name_to_type[d] == mesh_lib.AxisTypes.Auto - and dst_sharding.mesh._name_to_type[d] == mesh_lib.AxisTypes.Explicit) + assert (src_sharding.mesh._name_to_type[d] == mesh_lib.AxisType.Auto + and dst_sharding.mesh._name_to_type[d] == mesh_lib.AxisType.Explicit) continue if s is not None and d is None: - assert (src_sharding.mesh._name_to_type[s] == mesh_lib.AxisTypes.Explicit - and dst_sharding.mesh._name_to_type[s] == mesh_lib.AxisTypes.Auto) + assert (src_sharding.mesh._name_to_type[s] == mesh_lib.AxisType.Explicit + and dst_sharding.mesh._name_to_type[s] == mesh_lib.AxisType.Auto) continue if d != s: raise ValueError( @@ -2821,7 +2821,7 @@ batching.skippable_batchers[reshard_p] = lambda _: () # -------------------- auto and user mode ------------------------- def _get_new_mesh(axes: str | tuple[str, ...] | None, - axis_type: mesh_lib.AxisTypes, name: str, + axis_type: mesh_lib.AxisType, name: str, error_on_manual_to_auto_explict=False): cur_mesh = mesh_lib.get_abstract_mesh() # TODO(yashkatariya): Maybe allow fetching mesh from the args to enable @@ -2837,8 +2837,8 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None, axes = (axes,) for a in axes: if (error_on_manual_to_auto_explict and - cur_mesh._name_to_type[a] == mesh_lib.AxisTypes.Manual and - axis_type in {mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Explicit}): + cur_mesh._name_to_type[a] == mesh_lib.AxisType.Manual and + axis_type in {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}): raise NotImplementedError( 'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not' ' allowed. Please file a bug at https://github.com/jax-ml/jax/issues' @@ -2855,7 +2855,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, raise TypeError("Missing required keyword argument: 'out_shardings'") else: _out_shardings = out_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'auto_axes', + new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'auto_axes', error_on_manual_to_auto_explict=True) with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( @@ -2867,7 +2867,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, @contextlib.contextmanager def use_auto_axes(*axes): - new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'use_auto_axes') + new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'use_auto_axes') with mesh_lib.use_abstract_mesh(new_mesh): yield @@ -2882,7 +2882,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None, raise TypeError("Missing required keyword argument: 'in_shardings'") else: _in_shardings = in_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'explicit_axes', + new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes', error_on_manual_to_auto_explict=True) with mesh_lib.use_abstract_mesh(new_mesh): args = mesh_cast(args, _in_shardings) @@ -2894,7 +2894,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None, @contextlib.contextmanager def use_explicit_axes(*axes): - new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, + new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'use_explicit_axes') with mesh_lib.use_abstract_mesh(new_mesh): yield diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 25a002efa..019411c77 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -108,7 +108,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): used_axes.extend(d.axes) remaining_axes = set(mesh.axis_names) - set(used_axes) replicated_axes = tuple(r for r in remaining_axes - if mesh._name_to_type[r] == mesh_lib.AxisTypes.Explicit) + if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings, sdy_sharding.logical_device_ids, replicated_axes) return sdy_sharding @@ -1301,7 +1301,7 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, if s is None: continue if sharding.mesh._name_to_type[s] in { - mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Manual}: + mesh_lib.AxisType.Auto, mesh_lib.AxisType.Manual}: raise ValueError( f'PartitionSpec passed to {api_name} cannot contain axis' ' names that are of type Auto or Manual. Got PartitionSpec:' @@ -1313,7 +1313,7 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], *, devices: Sequence[xc.Device] | None = None, - axis_types: tuple[mesh_lib.AxisTypes, ...] | None = None + axis_types: tuple[mesh_lib.AxisType, ...] | None = None ) -> mesh_lib.Mesh: """Creates an efficient mesh with the shape and axis names specified. diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 7db88f447..c55dc2a56 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1576,7 +1576,7 @@ def with_and_without_mesh(f): ))(with_mesh_from_kwargs(f)) def with_user_mesh(sizes, names, axis_types=None): - axis_types = ((mesh_lib.AxisTypes.Explicit,) * len(names) + axis_types = ((mesh_lib.AxisType.Explicit,) * len(names) if axis_types is None else axis_types) def decorator(fn): def mesh_fn(*args, **kwargs): diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c6a2d8d7a..a51d1fc7d 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -47,7 +47,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer -from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, use_abstract_mesh, +from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, get_abstract_mesh) from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, @@ -487,21 +487,21 @@ def _as_manual_mesh(mesh, auto: frozenset): cur_mesh = mesh explicit_axes, auto_axes = set(), set() # type: ignore for a in auto: - if cur_mesh._name_to_type[a] == AxisTypes.Auto: + if cur_mesh._name_to_type[a] == AxisType.Auto: auto_axes.add(a) else: - assert cur_mesh._name_to_type[a] == AxisTypes.Explicit + assert cur_mesh._name_to_type[a] == AxisType.Explicit explicit_axes.add(a) new_axis_types = [] for n in mesh.axis_names: if n in manual_axes: - new_axis_types.append(AxisTypes.Manual) + new_axis_types.append(AxisType.Manual) elif n in auto_axes: - new_axis_types.append(AxisTypes.Auto) + new_axis_types.append(AxisType.Auto) else: assert n in explicit_axes - new_axis_types.append(AxisTypes.Explicit) + new_axis_types.append(AxisType.Explicit) return AbstractMesh(mesh.axis_sizes, mesh.axis_names, axis_types=tuple(new_axis_types)) @@ -1901,7 +1901,7 @@ def _all_newly_manual_mesh_names( vmap_spmd_names = set(axis_env.spmd_axis_names) if not (ctx_mesh := get_abstract_mesh()).empty: mesh = ctx_mesh - already_manual_names = set(ctx_mesh._axis_types_dict.get(AxisTypes.Manual, ())) + already_manual_names = set(ctx_mesh._axis_types_dict.get(AxisType.Manual, ())) else: # TODO(mattjj): remove this mechanism when we revise mesh scopes already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names diff --git a/jax/sharding.py b/jax/sharding.py index 58c099d1f..6ddc81584 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -31,7 +31,7 @@ from jax._src.partition_spec import ( from jax._src.interpreters.pxla import Mesh as Mesh from jax._src.mesh import ( AbstractMesh as AbstractMesh, - AxisTypes as AxisTypes, + AxisType as AxisType, ) _deprecations = { diff --git a/tests/array_test.py b/tests/array_test.py index 96a4a6517..4184c835a 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -32,7 +32,7 @@ from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip -from jax._src.mesh import AxisTypes +from jax._src.mesh import AxisType from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, @@ -1356,40 +1356,40 @@ class ShardingTest(jtu.JaxTestCase): ValueError, 'Number of axis names should match the number of axis_types'): jtu.create_mesh((2, 1), ('x', 'y'), - axis_types=jax.sharding.AxisTypes.Auto) + axis_types=jax.sharding.AxisType.Auto) with self.assertRaisesRegex( ValueError, 'Number of axis names should match the number of axis_types'): jax.sharding.AbstractMesh((2, 1), ('x', 'y'), - axis_types=jax.sharding.AxisTypes.Auto) + axis_types=jax.sharding.AxisType.Auto) def test_make_mesh_axis_types(self): - Auto, Explicit, Manual = AxisTypes.Auto, AxisTypes.Explicit, AxisTypes.Manual + Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto) mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto) self.assertEqual(mesh1, mesh2) mesh = jax.make_mesh((1, 1), ('x', 'y')) - self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Auto: ('x', 'y')}) + self.assertDictEqual(mesh._axis_types_dict, {AxisType.Auto: ('x', 'y')}) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Auto, Manual)) self.assertDictEqual( - mesh._axis_types_dict, {AxisTypes.Auto: ('y',), AxisTypes.Explicit: ('x',), - AxisTypes.Manual: ('z',)}) + mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',), + AxisType.Manual: ('z',)}) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Explicit, Manual)) - self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Explicit: ('x', 'y'), - AxisTypes.Manual: ('z',)}) + self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y'), + AxisType.Manual: ('z',)}) mesh = jax.make_mesh((1, 1), ('x', 'y'), axis_types=(Explicit, Explicit)) - self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Explicit: ('x', 'y')}) + self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y')}) mesh = jax.make_mesh((1,), 'model', axis_types=Manual) - self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Manual: ('model',)}) + self.assertDictEqual(mesh._axis_types_dict, {AxisType.Manual: ('model',)}) with self.assertRaisesRegex( ValueError, diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 4c1b19cb1..e962653ed 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -21,7 +21,7 @@ import jax from jax._src import core from jax._src import config from jax._src import test_util as jtu -from jax.sharding import NamedSharding, PartitionSpec as P, AxisTypes +from jax.sharding import NamedSharding, PartitionSpec as P, AxisType import jax.numpy as jnp from jax._src.state.types import (RefEffect) @@ -212,7 +212,7 @@ class MutableArrayTest(jtu.JaxTestCase): def test_explicit_sharding_after_indexing(self): # https://github.com/jax-ml/jax/issues/26936 mesh = jtu.create_mesh((1, 1), ('x', 'y'), - axis_types=(AxisTypes.Explicit,) * 2) + axis_types=(AxisType.Explicit,) * 2) sharding = NamedSharding(mesh, P('x', 'y')) @jax.jit diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5d4e1939d..e3ac00133 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -57,7 +57,7 @@ from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes, use_auto_axes, use_explicit_axes, reshard) from jax._src.named_sharding import DuplicateSpecError from jax._src import mesh as mesh_lib -from jax._src.mesh import AxisTypes +from jax._src.mesh import AxisType from jax._src.interpreters import pxla from jax._src.lib.mlir import dialects from jax._src import xla_bridge @@ -5254,7 +5254,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): @jtu.with_user_mesh((1,), 'x') def test_broadcasting_nary_error(self, mesh): mesh2 = Mesh([jax.devices()[0]], 'y', - axis_types=(mesh_lib.AxisTypes.Explicit,)) + axis_types=(mesh_lib.AxisType.Explicit,)) arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P())) arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P())) @@ -5572,7 +5572,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_explicit_mode_no_context_mesh(self): mesh = jtu.create_mesh((4, 2), ('x', 'y'), - axis_types=(AxisTypes.Explicit,) * 2) + axis_types=(AxisType.Explicit,) * 2) abstract_mesh = mesh.abstract_mesh np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5597,7 +5597,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_auto_mode_no_context_mesh(self): mesh = jtu.create_mesh((4, 2), ('x', 'y'), - axis_types=(AxisTypes.Auto,) * 2) + axis_types=(AxisType.Auto,) * 2) abstract_mesh = mesh.abstract_mesh np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5627,7 +5627,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, - 'mesh_cast should only be used when AxisTypes changes between the input' + 'mesh_cast should only be used when AxisType changes between the input' ' mesh and the target mesh'): f(arr) @@ -5637,18 +5637,18 @@ class ShardingInTypesTest(jtu.JaxTestCase): with self.assertRaisesRegex( ValueError, - 'mesh_cast should only be used when AxisTypes changes between the input' + 'mesh_cast should only be used when AxisType changes between the input' ' mesh and the target mesh'): g(arr) @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(AxisTypes.Explicit, AxisTypes.Auto)) + axis_types=(AxisType.Explicit, AxisType.Auto)) def test_mesh_cast_explicit_data_movement_error(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) full_user_mesh = mesh_lib.AbstractMesh( - (2, 2), ('x', 'y'), axis_types=(AxisTypes.Explicit,) * 2) + (2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) @jax.jit def f(x): @@ -5912,7 +5912,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) self.check_wsc_in_lowered(f.lower(arr).as_text()) - @jtu.with_user_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisTypes.Auto,) * 2) + @jtu.with_user_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2) def test_only_auto(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5932,7 +5932,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_auto_user(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisTypes.Auto,) * 2) + axis_types=(mesh_lib.AxisType.Auto,) * 2) np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -5951,8 +5951,8 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertNotIn('unspecified_dims', lowered_text) mesh2 = jtu.create_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisTypes.Explicit, - mesh_lib.AxisTypes.Auto)) + axis_types=(mesh_lib.AxisType.Explicit, + mesh_lib.AxisType.Auto)) with jax.sharding.use_mesh(mesh2): arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None))) @@ -5965,8 +5965,8 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertTrue(lowered_text.count("unspecified_dims") == 5) mesh3 = jtu.create_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisTypes.Auto, - mesh_lib.AxisTypes.Explicit)) + axis_types=(mesh_lib.AxisType.Auto, + mesh_lib.AxisType.Explicit)) with jax.sharding.use_mesh(mesh3): arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P(None, 'x'))) @@ -6022,7 +6022,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisTypes.Auto,) * 2) + axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_full_user(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6132,11 +6132,11 @@ class ShardingInTypesTest(jtu.JaxTestCase): return x self.assertDictEqual(arr.sharding.mesh._axis_types_dict, - {AxisTypes.Explicit: ('x',)}) + {AxisType.Explicit: ('x',)}) out = f(arr) self.assertArraysEqual(out, np_inp) self.assertDictEqual(out.sharding.mesh._axis_types_dict, - {AxisTypes.Auto: ('x',)}) + {AxisType.Auto: ('x',)}) @jtu.with_user_mesh((2,), 'x') def test_inputs_different_context(self, mesh): @@ -6144,11 +6144,11 @@ class ShardingInTypesTest(jtu.JaxTestCase): s = NamedSharding(mesh, P('x')) arr = jax.device_put(np_inp, s) - auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisTypes.Auto,)) + auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisType.Auto,)) with jax.sharding.use_mesh(auto_mesh): arr2 = jnp.ones(8) self.assertDictEqual(arr2.sharding.mesh._axis_types_dict, - {AxisTypes.Auto: ('x',)}) + {AxisType.Auto: ('x',)}) @jax.jit def f(x, y): @@ -6156,16 +6156,16 @@ class ShardingInTypesTest(jtu.JaxTestCase): out1, out2 = f(arr, arr2) self.assertDictEqual(out1.sharding.mesh._axis_types_dict, - {AxisTypes.Explicit: ('x',)}) + {AxisType.Explicit: ('x',)}) self.assertDictEqual(out2.sharding.mesh._axis_types_dict, - {AxisTypes.Auto: ('x',)}) + {AxisType.Auto: ('x',)}) @jtu.with_user_mesh((2,), 'x') def test_output_different_context_error(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x'))) - auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisTypes.Auto,)).abstract_mesh + auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisType.Auto,)).abstract_mesh @jax.jit def f(x, y): @@ -6188,8 +6188,8 @@ class ShardingInTypesTest(jtu.JaxTestCase): g(arr1, arr2) @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'), - axis_types=(AxisTypes.Explicit, AxisTypes.Explicit, - AxisTypes.Auto)) + axis_types=(AxisType.Explicit, AxisType.Explicit, + AxisType.Auto)) def test_out_sharding_mix_axis_types(self, mesh): np_inp = np.arange(16).reshape(4, 2, 2) s = NamedSharding(mesh, P('x', None, None)) @@ -6267,13 +6267,13 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_aval_spec_explicit_auto_complete(self): abstract_mesh = mesh_lib.AbstractMesh( - (2,), 'x', axis_types=AxisTypes.Explicit) + (2,), 'x', axis_types=AxisType.Explicit) s = NamedSharding(abstract_mesh, P('x')) out = core.ShapedArray((8, 2), jnp.int32, sharding=s) self.assertEqual(out.sharding.spec, P('x', None)) @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisTypes.Auto,) * 2) + axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_user_mode(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6403,8 +6403,8 @@ class ShardingInTypesTest(jtu.JaxTestCase): f(arr1, arr2, arr3) @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisTypes.Explicit, - mesh_lib.AxisTypes.Auto)) + axis_types=(mesh_lib.AxisType.Explicit, + mesh_lib.AxisType.Auto)) def test_mix_to_full_user_mode(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6430,7 +6430,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(mesh_lib.AxisTypes.Auto,) * 2) + axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_partial_user(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6545,7 +6545,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_auto_axes_top_level(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), - axis_types=(AxisTypes.Explicit,) * 2) + axis_types=(AxisType.Explicit,) * 2) np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) @@ -6567,7 +6567,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_explicit_axes_top_level(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), - axis_types=(AxisTypes.Auto,) * 2) + axis_types=(AxisType.Auto,) * 2) np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) @@ -6590,7 +6590,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_reshard_eager_mode(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), - axis_types=(AxisTypes.Explicit,) * 2) + axis_types=(AxisType.Explicit,) * 2) np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) @@ -6626,7 +6626,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(AxisTypes.Auto,) * 2) + axis_types=(AxisType.Auto,) * 2) def test_full_visible_outside_jit(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6817,7 +6817,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.shape, (4, 8)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisTypes.Auto) + @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) def test_shmap_close_over(self, mesh): const = jnp.arange(8) def f(): @@ -6828,7 +6828,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): jax.jit(shmap_f)() # doesn't crash @jtu.with_user_mesh((2, 2), ('x', 'y'), - axis_types=(AxisTypes.Auto,) * 2) + axis_types=(AxisType.Auto,) * 2) def test_shmap_close_over_partial_auto(self, mesh): const = jnp.arange(8) def f(): @@ -6861,7 +6861,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): jax.lax.with_sharding_constraint(np.arange(8), s) s = NamedSharding(Mesh(mesh.devices, mesh.axis_names, - axis_types=(AxisTypes.Explicit, AxisTypes.Auto)), + axis_types=(AxisType.Explicit, AxisType.Auto)), P('x', P.UNCONSTRAINED)) with self.assertRaisesRegex( ValueError, @@ -6876,7 +6876,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_pspec_einsum_no_context_mesh(self): mesh = jtu.create_mesh((1, 1), ('x', 'y'), - axis_types=(AxisTypes.Explicit,) * 2) + axis_types=(AxisType.Explicit,) * 2) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', None))) @@ -6891,7 +6891,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): f(arr, arr2) @jtu.with_user_mesh((2, 1), ('x', 'y'), - axis_types=(AxisTypes.Auto,) * 2) + axis_types=(AxisType.Auto,) * 2) def test_error_on_canonicalize_under_auto_mode(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -6968,7 +6968,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', 'y'))) @jtu.with_user_mesh((2, 1), ('x', 'y'), - axis_types=(AxisTypes.Auto,) * 2) + axis_types=(AxisType.Auto,) * 2) def test_axes_api_error_manual_to_auto_explicit(self, mesh): def g(x): return auto_axes(lambda a: a * 2, axes=('x', 'y'), @@ -7037,7 +7037,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec)) def test_auto_axes_computation_follows_data_error(self): - mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisTypes.Explicit,)) + mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np.arange(8), s) @@ -7050,7 +7050,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_divisbility_aval_error(self): abstract_mesh = mesh_lib.AbstractMesh( - (2,), ('x',), axis_types=AxisTypes.Explicit) + (2,), ('x',), axis_types=AxisType.Explicit) s = NamedSharding(abstract_mesh, P('x')) with self.assertRaisesRegex( ValueError, 'does not evenly divide the dimension size'): @@ -7082,7 +7082,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) def test_set_mesh(self): - mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisTypes.Explicit,)) + mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) prev_mesh = config.device_context.value prev_abstract_mesh = config.abstract_mesh_context_manager.value try: @@ -7104,7 +7104,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisTypes.Auto) + @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) def test_explicit_axes_late_bind(self, mesh): @explicit_axes def f(x): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index e9157759b..f8d5a11e8 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -40,7 +40,7 @@ from jax._src import test_util as jtu from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals -from jax._src.mesh import AxisTypes +from jax._src.mesh import AxisType from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util @@ -1891,7 +1891,7 @@ class ShardMapTest(jtu.JaxTestCase): def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisTypes.Manual: ('i',), AxisTypes.Auto: ('j',)}) + {AxisType.Manual: ('i',), AxisType.Auto: ('j',)}) x = jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P(None, 'j'))) return x * x @@ -1923,11 +1923,11 @@ class ShardMapTest(jtu.JaxTestCase): def test_partial_auto_explicit_no_use_mesh(self): mesh = jtu.create_mesh((2, 2), ('i', 'j'), - axis_types=(AxisTypes.Explicit,) * 2) + axis_types=(AxisType.Explicit,) * 2) def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)}) + {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)}) self.assertEqual(x.aval.sharding.spec, P(None, 'j')) out = x * x self.assertEqual(out.aval.sharding.spec, P(None, 'j')) @@ -1953,7 +1953,7 @@ class ShardMapTest(jtu.JaxTestCase): def test_partial_auto_explicit(self, mesh): def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)}) + {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)}) self.assertEqual(x.aval.sharding.spec, P(None, 'j')) out = x * x self.assertEqual(out.aval.sharding.spec, P(None, 'j')) @@ -1997,8 +1997,8 @@ class ShardMapTest(jtu.JaxTestCase): def test_partial_auto_explicit_multi_explicit(self, mesh): def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, - {AxisTypes.Manual: ('i', 'j'), - AxisTypes.Explicit: ('k', 'l')}) + {AxisType.Manual: ('i', 'j'), + AxisType.Explicit: ('k', 'l')}) self.assertEqual(x.aval.sharding.spec, P(None, None, 'k', 'l')) out = x.T self.assertEqual(out.aval.sharding.spec, P('l', 'k', None, None))