Rename AxisTypes enum to AxisType

PiperOrigin-RevId: 736935746
This commit is contained in:
Yash Katariya 2025-03-14 11:47:33 -07:00 committed by jax authors
parent bdb6d03322
commit 88d4bc3d45
12 changed files with 109 additions and 109 deletions

View File

@ -39,7 +39,7 @@ from jax._src import config
from jax._src import effects
from jax._src import compute_on
from jax._src import mesh as mesh_lib
from jax._src.mesh import AxisTypes
from jax._src.mesh import AxisType
from jax._src.partition_spec import PartitionSpec as P
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
@ -1835,7 +1835,7 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
temp_s = s[0] if isinstance(s, tuple) else s
new_spec.append(
None
if mesh._name_to_type[temp_s] in (AxisTypes.Auto, AxisTypes.Manual)
if mesh._name_to_type[temp_s] in (AxisType.Auto, AxisType.Manual)
else s)
return P(*new_spec)

View File

@ -103,7 +103,7 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)
class AxisTypes(enum.Enum):
class AxisType(enum.Enum):
Auto = enum.auto()
Explicit = enum.auto()
Manual = enum.auto()
@ -112,10 +112,10 @@ class AxisTypes(enum.Enum):
return self.name
def _normalize_axis_types(axis_names, axis_types):
axis_types = ((AxisTypes.Auto,) * len(axis_names)
axis_types = ((AxisType.Auto,) * len(axis_names)
if axis_types is None else axis_types)
if not isinstance(axis_types, tuple):
assert isinstance(axis_types, AxisTypes), axis_types
assert isinstance(axis_types, AxisType), axis_types
axis_types = (axis_types,)
if len(axis_names) != len(axis_types):
raise ValueError(
@ -123,12 +123,12 @@ def _normalize_axis_types(axis_names, axis_types):
f" axis_names={axis_names} and axis_types={axis_types}")
return axis_types
def all_axis_types_match(axis_types, ty: AxisTypes) -> bool:
def all_axis_types_match(axis_types, ty: AxisType) -> bool:
if not axis_types:
return False
return all(t == ty for t in axis_types)
def any_axis_types_match(axis_types, ty: AxisTypes) -> bool:
def any_axis_types_match(axis_types, ty: AxisType) -> bool:
if not axis_types:
return False
return any(t == ty for t in axis_types)
@ -137,42 +137,42 @@ def any_axis_types_match(axis_types, ty: AxisTypes) -> bool:
class _BaseMesh:
axis_names: tuple[MeshAxisName, ...]
shape_tuple: tuple[tuple[str, int], ...]
_axis_types: tuple[AxisTypes, ...]
_axis_types: tuple[AxisType, ...]
@property
def axis_types(self) -> tuple[AxisTypes, ...]:
def axis_types(self) -> tuple[AxisType, ...]:
return self._axis_types
@functools.cached_property
def _are_all_axes_manual(self) -> bool:
return all_axis_types_match(self._axis_types, AxisTypes.Manual)
return all_axis_types_match(self._axis_types, AxisType.Manual)
@functools.cached_property
def _are_all_axes_auto(self) -> bool:
return all_axis_types_match(self._axis_types, AxisTypes.Auto)
return all_axis_types_match(self._axis_types, AxisType.Auto)
@functools.cached_property
def _are_all_axes_explicit(self) -> bool:
return all_axis_types_match(self._axis_types, AxisTypes.Explicit)
return all_axis_types_match(self._axis_types, AxisType.Explicit)
@functools.cached_property
def _are_all_axes_auto_or_manual(self) -> bool:
if not self._axis_types:
return False
return all(t == AxisTypes.Auto or t == AxisTypes.Manual
return all(t == AxisType.Auto or t == AxisType.Manual
for t in self._axis_types)
@functools.cached_property
def _any_axis_manual(self) -> bool:
return any_axis_types_match(self._axis_types, AxisTypes.Manual)
return any_axis_types_match(self._axis_types, AxisType.Manual)
@functools.cached_property
def _any_axis_auto(self) -> bool:
return any_axis_types_match(self._axis_types, AxisTypes.Auto)
return any_axis_types_match(self._axis_types, AxisType.Auto)
@functools.cached_property
def _any_axis_explicit(self) -> bool:
return any_axis_types_match(self._axis_types, AxisTypes.Explicit)
return any_axis_types_match(self._axis_types, AxisType.Explicit)
@functools.cached_property
def _axis_types_dict(self):
@ -247,7 +247,7 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator):
def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
axis_names: str | Sequence[MeshAxisName], *,
axis_types: tuple[AxisTypes, ...] | None = None):
axis_types: tuple[AxisType, ...] | None = None):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
@ -443,7 +443,7 @@ class AbstractMesh(_BaseMesh):
"""
def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...],
*, axis_types: AxisTypes | tuple[AxisTypes, ...] | None = None):
*, axis_types: AxisType | tuple[AxisType, ...] | None = None):
self.axis_sizes = axis_sizes
self.axis_names = axis_names
self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0
@ -494,7 +494,7 @@ class AbstractMesh(_BaseMesh):
def abstract_mesh(self):
return self
def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisTypes]):
def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisType]):
new_axis_types = tuple(name_to_type[n] if n in name_to_type else a
for n, a in zip(self.axis_names, self._axis_types))
return AbstractMesh(self.axis_sizes, self.axis_names,

View File

@ -409,7 +409,7 @@ def named_sharding_to_xla_hlo_sharding(
special_axes = {}
mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items()
if t == mesh_lib.AxisTypes.Manual}
if t == mesh_lib.AxisType.Manual}
manual_axes = self._manual_axes.union(mesh_manual_axes)
if manual_axes:
axis_names = self.mesh.axis_names
@ -564,7 +564,7 @@ def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {pspec}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
if (mesh_lib.AxisTypes.Auto not in mesh._axis_types_dict and
if (mesh_lib.AxisType.Auto not in mesh._axis_types_dict and
PartitionSpec.UNCONSTRAINED in pspec):
raise ValueError(
f'{pspec} cannot contain'

View File

@ -2488,7 +2488,7 @@ def check_shardings_are_auto(shardings_flat):
if not isinstance(s, NamedSharding):
continue
mesh = s.mesh.abstract_mesh
if not all(mesh._name_to_type[i] == mesh_lib.AxisTypes.Auto
if not all(mesh._name_to_type[i] == mesh_lib.AxisType.Auto
for axes in s.spec
if axes is not PartitionSpec.UNCONSTRAINED and axes is not None
for i in (axes if isinstance(axes, tuple) else (axes,))):
@ -2712,7 +2712,7 @@ def _mesh_cast_abstract_eval(aval, dst_sharding):
if (src_sharding.mesh._axis_types_dict == dst_sharding.mesh._axis_types_dict
and src_sharding.spec != dst_sharding.spec):
raise ValueError(
'mesh_cast should only be used when AxisTypes changes between the'
'mesh_cast should only be used when AxisType changes between the'
' input mesh and the target mesh. Got src'
f' axis_types={src_sharding.mesh._axis_types_dict} and dst'
f' axis_types={dst_sharding.mesh._axis_types_dict}. To reshard between'
@ -2723,12 +2723,12 @@ def _mesh_cast_abstract_eval(aval, dst_sharding):
if s is None and d is None:
continue
if s is None and d is not None:
assert (src_sharding.mesh._name_to_type[d] == mesh_lib.AxisTypes.Auto
and dst_sharding.mesh._name_to_type[d] == mesh_lib.AxisTypes.Explicit)
assert (src_sharding.mesh._name_to_type[d] == mesh_lib.AxisType.Auto
and dst_sharding.mesh._name_to_type[d] == mesh_lib.AxisType.Explicit)
continue
if s is not None and d is None:
assert (src_sharding.mesh._name_to_type[s] == mesh_lib.AxisTypes.Explicit
and dst_sharding.mesh._name_to_type[s] == mesh_lib.AxisTypes.Auto)
assert (src_sharding.mesh._name_to_type[s] == mesh_lib.AxisType.Explicit
and dst_sharding.mesh._name_to_type[s] == mesh_lib.AxisType.Auto)
continue
if d != s:
raise ValueError(
@ -2821,7 +2821,7 @@ batching.skippable_batchers[reshard_p] = lambda _: ()
# -------------------- auto and user mode -------------------------
def _get_new_mesh(axes: str | tuple[str, ...] | None,
axis_type: mesh_lib.AxisTypes, name: str,
axis_type: mesh_lib.AxisType, name: str,
error_on_manual_to_auto_explict=False):
cur_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Maybe allow fetching mesh from the args to enable
@ -2837,8 +2837,8 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None,
axes = (axes,)
for a in axes:
if (error_on_manual_to_auto_explict and
cur_mesh._name_to_type[a] == mesh_lib.AxisTypes.Manual and
axis_type in {mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Explicit}):
cur_mesh._name_to_type[a] == mesh_lib.AxisType.Manual and
axis_type in {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}):
raise NotImplementedError(
'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not'
' allowed. Please file a bug at https://github.com/jax-ml/jax/issues'
@ -2855,7 +2855,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
raise TypeError("Missing required keyword argument: 'out_shardings'")
else:
_out_shardings = out_shardings
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'auto_axes',
new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'auto_axes',
error_on_manual_to_auto_explict=True)
with mesh_lib.use_abstract_mesh(new_mesh):
in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual(
@ -2867,7 +2867,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
@contextlib.contextmanager
def use_auto_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'use_auto_axes')
new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'use_auto_axes')
with mesh_lib.use_abstract_mesh(new_mesh):
yield
@ -2882,7 +2882,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
raise TypeError("Missing required keyword argument: 'in_shardings'")
else:
_in_shardings = in_shardings
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'explicit_axes',
new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes',
error_on_manual_to_auto_explict=True)
with mesh_lib.use_abstract_mesh(new_mesh):
args = mesh_cast(args, _in_shardings)
@ -2894,7 +2894,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
@contextlib.contextmanager
def use_explicit_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit,
new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit,
'use_explicit_axes')
with mesh_lib.use_abstract_mesh(new_mesh):
yield

View File

@ -108,7 +108,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
used_axes.extend(d.axes)
remaining_axes = set(mesh.axis_names) - set(used_axes)
replicated_axes = tuple(r for r in remaining_axes
if mesh._name_to_type[r] == mesh_lib.AxisTypes.Explicit)
if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit)
return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
sdy_sharding.logical_device_ids, replicated_axes)
return sdy_sharding
@ -1301,7 +1301,7 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
if s is None:
continue
if sharding.mesh._name_to_type[s] in {
mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Manual}:
mesh_lib.AxisType.Auto, mesh_lib.AxisType.Manual}:
raise ValueError(
f'PartitionSpec passed to {api_name} cannot contain axis'
' names that are of type Auto or Manual. Got PartitionSpec:'
@ -1313,7 +1313,7 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
*, devices: Sequence[xc.Device] | None = None,
axis_types: tuple[mesh_lib.AxisTypes, ...] | None = None
axis_types: tuple[mesh_lib.AxisType, ...] | None = None
) -> mesh_lib.Mesh:
"""Creates an efficient mesh with the shape and axis names specified.

View File

@ -1576,7 +1576,7 @@ def with_and_without_mesh(f):
))(with_mesh_from_kwargs(f))
def with_user_mesh(sizes, names, axis_types=None):
axis_types = ((mesh_lib.AxisTypes.Explicit,) * len(names)
axis_types = ((mesh_lib.AxisType.Explicit,) * len(names)
if axis_types is None else axis_types)
def decorator(fn):
def mesh_fn(*args, **kwargs):

View File

@ -47,7 +47,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, use_abstract_mesh,
from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh,
get_abstract_mesh)
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
@ -487,21 +487,21 @@ def _as_manual_mesh(mesh, auto: frozenset):
cur_mesh = mesh
explicit_axes, auto_axes = set(), set() # type: ignore
for a in auto:
if cur_mesh._name_to_type[a] == AxisTypes.Auto:
if cur_mesh._name_to_type[a] == AxisType.Auto:
auto_axes.add(a)
else:
assert cur_mesh._name_to_type[a] == AxisTypes.Explicit
assert cur_mesh._name_to_type[a] == AxisType.Explicit
explicit_axes.add(a)
new_axis_types = []
for n in mesh.axis_names:
if n in manual_axes:
new_axis_types.append(AxisTypes.Manual)
new_axis_types.append(AxisType.Manual)
elif n in auto_axes:
new_axis_types.append(AxisTypes.Auto)
new_axis_types.append(AxisType.Auto)
else:
assert n in explicit_axes
new_axis_types.append(AxisTypes.Explicit)
new_axis_types.append(AxisType.Explicit)
return AbstractMesh(mesh.axis_sizes, mesh.axis_names,
axis_types=tuple(new_axis_types))
@ -1901,7 +1901,7 @@ def _all_newly_manual_mesh_names(
vmap_spmd_names = set(axis_env.spmd_axis_names)
if not (ctx_mesh := get_abstract_mesh()).empty:
mesh = ctx_mesh
already_manual_names = set(ctx_mesh._axis_types_dict.get(AxisTypes.Manual, ()))
already_manual_names = set(ctx_mesh._axis_types_dict.get(AxisType.Manual, ()))
else:
# TODO(mattjj): remove this mechanism when we revise mesh scopes
already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names

View File

@ -31,7 +31,7 @@ from jax._src.partition_spec import (
from jax._src.interpreters.pxla import Mesh as Mesh
from jax._src.mesh import (
AbstractMesh as AbstractMesh,
AxisTypes as AxisTypes,
AxisType as AxisType,
)
_deprecations = {

View File

@ -32,7 +32,7 @@ from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import dialects, ir
from jax._src.util import safe_zip
from jax._src.mesh import AxisTypes
from jax._src.mesh import AxisType
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (
_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map,
@ -1356,40 +1356,40 @@ class ShardingTest(jtu.JaxTestCase):
ValueError,
'Number of axis names should match the number of axis_types'):
jtu.create_mesh((2, 1), ('x', 'y'),
axis_types=jax.sharding.AxisTypes.Auto)
axis_types=jax.sharding.AxisType.Auto)
with self.assertRaisesRegex(
ValueError,
'Number of axis names should match the number of axis_types'):
jax.sharding.AbstractMesh((2, 1), ('x', 'y'),
axis_types=jax.sharding.AxisTypes.Auto)
axis_types=jax.sharding.AxisType.Auto)
def test_make_mesh_axis_types(self):
Auto, Explicit, Manual = AxisTypes.Auto, AxisTypes.Explicit, AxisTypes.Manual
Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual
mesh1 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto)
self.assertEqual(mesh1, mesh2)
mesh = jax.make_mesh((1, 1), ('x', 'y'))
self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Auto: ('x', 'y')})
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Auto: ('x', 'y')})
mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'),
axis_types=(Explicit, Auto, Manual))
self.assertDictEqual(
mesh._axis_types_dict, {AxisTypes.Auto: ('y',), AxisTypes.Explicit: ('x',),
AxisTypes.Manual: ('z',)})
mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',),
AxisType.Manual: ('z',)})
mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'),
axis_types=(Explicit, Explicit, Manual))
self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Explicit: ('x', 'y'),
AxisTypes.Manual: ('z',)})
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y'),
AxisType.Manual: ('z',)})
mesh = jax.make_mesh((1, 1), ('x', 'y'), axis_types=(Explicit, Explicit))
self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Explicit: ('x', 'y')})
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Explicit: ('x', 'y')})
mesh = jax.make_mesh((1,), 'model', axis_types=Manual)
self.assertDictEqual(mesh._axis_types_dict, {AxisTypes.Manual: ('model',)})
self.assertDictEqual(mesh._axis_types_dict, {AxisType.Manual: ('model',)})
with self.assertRaisesRegex(
ValueError,

View File

@ -21,7 +21,7 @@ import jax
from jax._src import core
from jax._src import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P, AxisTypes
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType
import jax.numpy as jnp
from jax._src.state.types import (RefEffect)
@ -212,7 +212,7 @@ class MutableArrayTest(jtu.JaxTestCase):
def test_explicit_sharding_after_indexing(self):
# https://github.com/jax-ml/jax/issues/26936
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
axis_types=(AxisTypes.Explicit,) * 2)
axis_types=(AxisType.Explicit,) * 2)
sharding = NamedSharding(mesh, P('x', 'y'))
@jax.jit

View File

@ -57,7 +57,7 @@ from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes,
use_auto_axes, use_explicit_axes, reshard)
from jax._src.named_sharding import DuplicateSpecError
from jax._src import mesh as mesh_lib
from jax._src.mesh import AxisTypes
from jax._src.mesh import AxisType
from jax._src.interpreters import pxla
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
@ -5254,7 +5254,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jtu.with_user_mesh((1,), 'x')
def test_broadcasting_nary_error(self, mesh):
mesh2 = Mesh([jax.devices()[0]], 'y',
axis_types=(mesh_lib.AxisTypes.Explicit,))
axis_types=(mesh_lib.AxisType.Explicit,))
arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P()))
arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@ -5572,7 +5572,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_explicit_mode_no_context_mesh(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'),
axis_types=(AxisTypes.Explicit,) * 2)
axis_types=(AxisType.Explicit,) * 2)
abstract_mesh = mesh.abstract_mesh
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
@ -5597,7 +5597,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_auto_mode_no_context_mesh(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'),
axis_types=(AxisTypes.Auto,) * 2)
axis_types=(AxisType.Auto,) * 2)
abstract_mesh = mesh.abstract_mesh
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
@ -5627,7 +5627,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
'mesh_cast should only be used when AxisTypes changes between the input'
'mesh_cast should only be used when AxisType changes between the input'
' mesh and the target mesh'):
f(arr)
@ -5637,18 +5637,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
'mesh_cast should only be used when AxisTypes changes between the input'
'mesh_cast should only be used when AxisType changes between the input'
' mesh and the target mesh'):
g(arr)
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types=(AxisTypes.Explicit, AxisTypes.Auto))
axis_types=(AxisType.Explicit, AxisType.Auto))
def test_mesh_cast_explicit_data_movement_error(self, mesh):
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
full_user_mesh = mesh_lib.AbstractMesh(
(2, 2), ('x', 'y'), axis_types=(AxisTypes.Explicit,) * 2)
(2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2)
@jax.jit
def f(x):
@ -5912,7 +5912,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
self.check_wsc_in_lowered(f.lower(arr).as_text())
@jtu.with_user_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisTypes.Auto,) * 2)
@jtu.with_user_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2)
def test_only_auto(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
@ -5932,7 +5932,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_auto_user(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types=(mesh_lib.AxisTypes.Auto,) * 2)
axis_types=(mesh_lib.AxisType.Auto,) * 2)
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@ -5951,8 +5951,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertNotIn('unspecified_dims', lowered_text)
mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types=(mesh_lib.AxisTypes.Explicit,
mesh_lib.AxisTypes.Auto))
axis_types=(mesh_lib.AxisType.Explicit,
mesh_lib.AxisType.Auto))
with jax.sharding.use_mesh(mesh2):
arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
@ -5965,8 +5965,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertTrue(lowered_text.count("unspecified_dims") == 5)
mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types=(mesh_lib.AxisTypes.Auto,
mesh_lib.AxisTypes.Explicit))
axis_types=(mesh_lib.AxisType.Auto,
mesh_lib.AxisType.Explicit))
with jax.sharding.use_mesh(mesh3):
arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P(None, 'x')))
@ -6022,7 +6022,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None)))
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types=(mesh_lib.AxisTypes.Auto,) * 2)
axis_types=(mesh_lib.AxisType.Auto,) * 2)
def test_full_auto_to_full_user(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
@ -6132,11 +6132,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
return x
self.assertDictEqual(arr.sharding.mesh._axis_types_dict,
{AxisTypes.Explicit: ('x',)})
{AxisType.Explicit: ('x',)})
out = f(arr)
self.assertArraysEqual(out, np_inp)
self.assertDictEqual(out.sharding.mesh._axis_types_dict,
{AxisTypes.Auto: ('x',)})
{AxisType.Auto: ('x',)})
@jtu.with_user_mesh((2,), 'x')
def test_inputs_different_context(self, mesh):
@ -6144,11 +6144,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np_inp, s)
auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisTypes.Auto,))
auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisType.Auto,))
with jax.sharding.use_mesh(auto_mesh):
arr2 = jnp.ones(8)
self.assertDictEqual(arr2.sharding.mesh._axis_types_dict,
{AxisTypes.Auto: ('x',)})
{AxisType.Auto: ('x',)})
@jax.jit
def f(x, y):
@ -6156,16 +6156,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out1, out2 = f(arr, arr2)
self.assertDictEqual(out1.sharding.mesh._axis_types_dict,
{AxisTypes.Explicit: ('x',)})
{AxisType.Explicit: ('x',)})
self.assertDictEqual(out2.sharding.mesh._axis_types_dict,
{AxisTypes.Auto: ('x',)})
{AxisType.Auto: ('x',)})
@jtu.with_user_mesh((2,), 'x')
def test_output_different_context_error(self, mesh):
np_inp1 = np.arange(16).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisTypes.Auto,)).abstract_mesh
auto_mesh = jax.make_mesh((2,), 'x', axis_types=(AxisType.Auto,)).abstract_mesh
@jax.jit
def f(x, y):
@ -6188,8 +6188,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
g(arr1, arr2)
@jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'),
axis_types=(AxisTypes.Explicit, AxisTypes.Explicit,
AxisTypes.Auto))
axis_types=(AxisType.Explicit, AxisType.Explicit,
AxisType.Auto))
def test_out_sharding_mix_axis_types(self, mesh):
np_inp = np.arange(16).reshape(4, 2, 2)
s = NamedSharding(mesh, P('x', None, None))
@ -6267,13 +6267,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_aval_spec_explicit_auto_complete(self):
abstract_mesh = mesh_lib.AbstractMesh(
(2,), 'x', axis_types=AxisTypes.Explicit)
(2,), 'x', axis_types=AxisType.Explicit)
s = NamedSharding(abstract_mesh, P('x'))
out = core.ShapedArray((8, 2), jnp.int32, sharding=s)
self.assertEqual(out.sharding.spec, P('x', None))
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types=(mesh_lib.AxisTypes.Auto,) * 2)
axis_types=(mesh_lib.AxisType.Auto,) * 2)
def test_full_user_mode(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
@ -6403,8 +6403,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
f(arr1, arr2, arr3)
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types=(mesh_lib.AxisTypes.Explicit,
mesh_lib.AxisTypes.Auto))
axis_types=(mesh_lib.AxisType.Explicit,
mesh_lib.AxisType.Auto))
def test_mix_to_full_user_mode(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
@ -6430,7 +6430,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types=(mesh_lib.AxisTypes.Auto,) * 2)
axis_types=(mesh_lib.AxisType.Auto,) * 2)
def test_full_auto_to_partial_user(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
@ -6545,7 +6545,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_auto_axes_top_level(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types=(AxisTypes.Explicit,) * 2)
axis_types=(AxisType.Explicit,) * 2)
np_inp = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@ -6567,7 +6567,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_explicit_axes_top_level(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types=(AxisTypes.Auto,) * 2)
axis_types=(AxisType.Auto,) * 2)
np_inp = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@ -6590,7 +6590,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_reshard_eager_mode(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types=(AxisTypes.Explicit,) * 2)
axis_types=(AxisType.Explicit,) * 2)
np_inp = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@ -6626,7 +6626,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types=(AxisTypes.Auto,) * 2)
axis_types=(AxisType.Auto,) * 2)
def test_full_visible_outside_jit(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
@ -6817,7 +6817,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.shape, (4, 8))
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x')))
@jtu.with_user_mesh((2,), ('x',), axis_types=AxisTypes.Auto)
@jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto)
def test_shmap_close_over(self, mesh):
const = jnp.arange(8)
def f():
@ -6828,7 +6828,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
jax.jit(shmap_f)() # doesn't crash
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types=(AxisTypes.Auto,) * 2)
axis_types=(AxisType.Auto,) * 2)
def test_shmap_close_over_partial_auto(self, mesh):
const = jnp.arange(8)
def f():
@ -6861,7 +6861,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
jax.lax.with_sharding_constraint(np.arange(8), s)
s = NamedSharding(Mesh(mesh.devices, mesh.axis_names,
axis_types=(AxisTypes.Explicit, AxisTypes.Auto)),
axis_types=(AxisType.Explicit, AxisType.Auto)),
P('x', P.UNCONSTRAINED))
with self.assertRaisesRegex(
ValueError,
@ -6876,7 +6876,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_pspec_einsum_no_context_mesh(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
axis_types=(AxisTypes.Explicit,) * 2)
axis_types=(AxisType.Explicit,) * 2)
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', None)))
@ -6891,7 +6891,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
f(arr, arr2)
@jtu.with_user_mesh((2, 1), ('x', 'y'),
axis_types=(AxisTypes.Auto,) * 2)
axis_types=(AxisType.Auto,) * 2)
def test_error_on_canonicalize_under_auto_mode(self, mesh):
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
@ -6968,7 +6968,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', 'y')))
@jtu.with_user_mesh((2, 1), ('x', 'y'),
axis_types=(AxisTypes.Auto,) * 2)
axis_types=(AxisType.Auto,) * 2)
def test_axes_api_error_manual_to_auto_explicit(self, mesh):
def g(x):
return auto_axes(lambda a: a * 2, axes=('x', 'y'),
@ -7037,7 +7037,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec))
def test_auto_axes_computation_follows_data_error(self):
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisTypes.Explicit,))
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np.arange(8), s)
@ -7050,7 +7050,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_divisbility_aval_error(self):
abstract_mesh = mesh_lib.AbstractMesh(
(2,), ('x',), axis_types=AxisTypes.Explicit)
(2,), ('x',), axis_types=AxisType.Explicit)
s = NamedSharding(abstract_mesh, P('x'))
with self.assertRaisesRegex(
ValueError, 'does not evenly divide the dimension size'):
@ -7082,7 +7082,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
def test_set_mesh(self):
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisTypes.Explicit,))
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
prev_mesh = config.device_context.value
prev_abstract_mesh = config.abstract_mesh_context_manager.value
try:
@ -7104,7 +7104,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
self.assertArraysEqual(out, np.arange(8) * 2)
@jtu.with_user_mesh((2,), ('x',), axis_types=AxisTypes.Auto)
@jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto)
def test_explicit_axes_late_bind(self, mesh):
@explicit_axes
def f(x):

View File

@ -40,7 +40,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.ad_checkpoint import saved_residuals
from jax._src.mesh import AxisTypes
from jax._src.mesh import AxisType
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
@ -1891,7 +1891,7 @@ class ShardMapTest(jtu.JaxTestCase):
def g(x):
self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
{AxisTypes.Manual: ('i',), AxisTypes.Auto: ('j',)})
{AxisType.Manual: ('i',), AxisType.Auto: ('j',)})
x = jax.lax.with_sharding_constraint(
x, jax.sharding.NamedSharding(mesh, P(None, 'j')))
return x * x
@ -1923,11 +1923,11 @@ class ShardMapTest(jtu.JaxTestCase):
def test_partial_auto_explicit_no_use_mesh(self):
mesh = jtu.create_mesh((2, 2), ('i', 'j'),
axis_types=(AxisTypes.Explicit,) * 2)
axis_types=(AxisType.Explicit,) * 2)
def g(x):
self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
{AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)})
{AxisType.Manual: ('i',), AxisType.Explicit: ('j',)})
self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
out = x * x
self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
@ -1953,7 +1953,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_partial_auto_explicit(self, mesh):
def g(x):
self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
{AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)})
{AxisType.Manual: ('i',), AxisType.Explicit: ('j',)})
self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
out = x * x
self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
@ -1997,8 +1997,8 @@ class ShardMapTest(jtu.JaxTestCase):
def test_partial_auto_explicit_multi_explicit(self, mesh):
def g(x):
self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
{AxisTypes.Manual: ('i', 'j'),
AxisTypes.Explicit: ('k', 'l')})
{AxisType.Manual: ('i', 'j'),
AxisType.Explicit: ('k', 'l')})
self.assertEqual(x.aval.sharding.spec, P(None, None, 'k', 'l'))
out = x.T
self.assertEqual(out.aval.sharding.spec, P('l', 'k', None, None))