Replace the Auto/User/Collective `AxisTypes` names with Hidden/Visible/Collective.

Replace the `with set_mesh(mesh):` context manager with `with use_mesh(mesh):`.

Also expose `AxisTypes` and `use_mesh` in the public API as `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
Yash Katariya 2025-01-16 17:55:15 -08:00 committed by jax authors
parent bd22bfef71
commit 49224d6cdb
13 changed files with 120 additions and 90 deletions
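
A minimal usage sketch of the renamed API (not part of the diff), assuming a runtime with at least two devices; the `jax.make_mesh(..., axis_types=...)` call mirrors the mesh construction used in the tests further down.

import numpy as np
import jax
from jax.sharding import AxisTypes, NamedSharding, PartitionSpec as P, use_mesh

# Hypothetical example: a 1-D mesh whose 'x' axis is Visible (formerly "User").
mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Visible: 'x'})

with use_mesh(mesh):  # formerly `with set_mesh(mesh):`
  arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x')))
  out = jax.jit(lambda x: x * 2)(arr)
  print(out.sharding)  # expected to stay sharded along 'x'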


@@ -729,9 +729,10 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
# Raising a ConcretizationTypeError would make sense, but for backward compatibility
# we raise an AttributeError so that hasattr() and getattr() work as expected.
raise AttributeError(self,
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
raise AttributeError(
self,
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
@property
def committed(self):
@@ -1674,7 +1675,7 @@ def _invalid_shape_error(shape: Shape, context: str=""):
# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def modify_spec_for_auto(spec, mesh) -> P:
def modify_spec_for_hidden(spec, mesh) -> P:
new_spec = [] # type: ignore
for s in spec:
if s is None:
@@ -1682,13 +1683,13 @@ def modify_spec_for_auto(spec, mesh) -> P:
else:
temp_s = s[0] if isinstance(s, tuple) else s
new_spec.append(
None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Hidden else s)
return P(*new_spec)
def _maybe_modify_sharding(sharding):
if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
if mesh_lib.AxisTypes.Hidden not in sharding.mesh.axis_types:
return sharding
new_spec = modify_spec_for_auto(sharding.spec, sharding.mesh)
new_spec = modify_spec_for_hidden(sharding.spec, sharding.mesh)
return sharding.with_spec(new_spec)


@@ -2596,7 +2596,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
# Don't emit a wsc under full manual mode to avoid increasing HLO size.
if aval.sharding.mesh._are_all_axes_collective:
return op
if aval.sharding.mesh._are_all_axes_auto:
if aval.sharding.mesh._are_all_axes_hidden:
return op
# TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
# `return op` early and avoid bloating HLO size.
@@ -2609,7 +2609,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
if sharding_proto is None else sharding_proto)
unspecified_dims = None
if aval.sharding.mesh._any_axis_auto:
if aval.sharding.mesh._any_axis_hidden:
# TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes
# as unspecified?
unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}


@@ -2172,7 +2172,7 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
for sp in a.sharding.spec])
if a.sharding.mesh._any_axis_auto else a.sharding.spec)
if a.sharding.mesh._any_axis_hidden else a.sharding.spec)
out.append(NamedSharding(
_abstract_to_concrete_mesh(a.sharding.mesh), spec))
else:


@@ -1883,7 +1883,7 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
slice_sizes, unique_indices, indices_are_sorted,
mode, fill_value):
# TODO(yashkatariya): Write a proper gather sharding rule.
if mesh_lib.get_abstract_mesh()._are_all_axes_auto: # type: ignore
if mesh_lib.get_abstract_mesh()._are_all_axes_hidden: # type: ignore
return None
raise GatherShardingError(
"Use `.at[...].get(out_specs=)` to provide output PartitionSpec for the"


@@ -49,7 +49,7 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
def call_sharding_rule(rule, num_out, *avals, **kwargs):
if config.sharding_in_types.value:
if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_auto: # type: ignore
if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_hidden: # type: ignore
return None if num_out is None else [None] * num_out
return rule(*avals, **kwargs)
return None if num_out is None else [None] * num_out


@@ -103,8 +103,8 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
class AxisTypes(enum.Enum):
Auto = enum.auto()
User = enum.auto()
Hidden = enum.auto()
Visible = enum.auto()
Collective = enum.auto()
def __repr__(self):
@@ -198,9 +198,15 @@ class Mesh(contextlib.ContextDecorator):
f"devices.ndim == {devices.ndim} and "
f"len(axis_names) == {len(axis_names)}.")
axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
axis_types = ({AxisTypes.Hidden: axis_names} if axis_types is None else
axis_types)
axis_types_tuple = tuple(axis_types.items())
if len(axis_names_to_types(axis_types).keys()) != len(axis_names):
raise ValueError(
"Number of axis names in axis_types should match the number of"
f" axis_names. Got axis_names={axis_names} and"
f" axis_types={axis_types}")
key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
val = _mesh_object_dict.get(key, None)
if val is not None:
@@ -356,16 +362,16 @@ class Mesh(contextlib.ContextDecorator):
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _are_all_axes_auto(self) -> bool:
return all(t == AxisTypes.Auto for t in self.axis_types.keys())
def _are_all_axes_hidden(self) -> bool:
return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_collective(self) -> bool:
return any(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_auto(self) -> bool:
return any(t == AxisTypes.Auto for t in self.axis_types.keys())
def _any_axis_hidden(self) -> bool:
return any(t == AxisTypes.Hidden for t in self.axis_types.keys())
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -397,9 +403,14 @@ class AbstractMesh:
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
else:
self._axis_names, self._axis_sizes = (), ()
self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
else axis_types)
self.axis_types = ({AxisTypes.Hidden: self._axis_names}
if axis_types is None else axis_types)
self._axis_types_tuple = tuple(self.axis_types.items())
if len(self._name_to_type.keys()) != len(self._axis_names):
raise ValueError(
"Number of axis names in axis_types should match the number of"
f" axis_names in shape_tuple. Got axis_names={self._axis_names} and"
f" axis_types={self.axis_types}")
def __hash__(self):
return hash((self.shape_tuple, self._axis_types_tuple))
@@ -461,16 +472,16 @@ class AbstractMesh:
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _are_all_axes_auto(self) -> bool:
return all(t == AxisTypes.Auto for t in self.axis_types.keys())
def _are_all_axes_hidden(self) -> bool:
return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_collective(self) -> bool:
return any(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_auto(self) -> bool:
return any(t == AxisTypes.Auto for t in self.axis_types.keys())
def _any_axis_hidden(self) -> bool:
return any(t == AxisTypes.Hidden for t in self.axis_types.keys())
@property
def devices(self):
@@ -535,7 +546,7 @@ def get_concrete_mesh():
@contextlib.contextmanager
def set_mesh(mesh: Mesh):
def use_mesh(mesh: Mesh):
with (set_abstract_mesh(mesh.abstract_mesh),
jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
yield


@@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
from jax._src.numpy import array_api_metadata
from jax._src.numpy import lax_numpy
from jax._src import mesh as mesh_lib
from jax._src.pjit import auto_mode, PartitionSpec
from jax._src.pjit import hidden_mode, PartitionSpec
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.ops import scatter
@@ -781,8 +781,8 @@ class _IndexUpdateRef:
fill_value=fill_value)
if out_spec is not None:
assert isinstance(out_spec, PartitionSpec)
take = auto_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
out_specs=out_spec)
take = hidden_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
out_specs=out_spec)
return take(self.array, self.index)
def set(self, values, *, indices_are_sorted=False, unique_indices=False,


@@ -2754,11 +2754,11 @@ def _get_new_mesh(axes: str | tuple[str, ...], axis_type: mesh_lib.AxisTypes):
new_mesh = cur_mesh.update_axis_types({axis_type: axes}) # type: ignore
return new_mesh
def auto_mode(fun, *, axes: str | tuple[str, ...], out_specs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
def hidden_mode(fun, *, axes: str | tuple[str, ...], out_specs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
def decorator(*args, **kwargs):
with mesh_lib.set_abstract_mesh(new_mesh):
in_specs = tree_map(lambda a: core.modify_spec_for_auto(
in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
a.sharding.spec, new_mesh), args)
args = sharding_cast(args, in_specs)
out = fun(*args, **kwargs)
@@ -2767,26 +2767,26 @@ def auto_mode(fun, *, axes: str | tuple[str, ...], out_specs):
@contextlib.contextmanager
def auto_mode_ctx(axes: str | tuple[str, ...]):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
def hidden_axes(axes: str | tuple[str, ...]):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
with mesh_lib.set_abstract_mesh(new_mesh):
yield
def user_mode(fun, *, axes: str | tuple[str, ...], in_specs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.User)
def visible_mode(fun, *, axes: str | tuple[str, ...], in_specs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
def decorator(*args, **kwargs):
with mesh_lib.set_abstract_mesh(new_mesh):
args = sharding_cast(args, in_specs)
out = fun(*args, **kwargs)
out_specs = tree_map(lambda o: core.modify_spec_for_auto(
out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
return sharding_cast(out, out_specs)
return decorator
@contextlib.contextmanager
def user_mode_ctx(axes: str | tuple[str, ...]):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.User)
def visible_axes(axes: str | tuple[str, ...]):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
with mesh_lib.set_abstract_mesh(new_mesh):
yield


@@ -78,10 +78,10 @@ def _check_axis_type_consistency(mesh, parsed_pspec):
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
if mesh_lib.AxisTypes.Auto not in mesh.axis_types and None in parsed_pspec:
if mesh_lib.AxisTypes.Hidden not in mesh.axis_types and None in parsed_pspec:
raise ValueError(
f'PartitionSpec {parsed_pspec.get_partition_spec()} cannot contain'
' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
' `P.UNCONSTRAINED` when no mesh axis_types are `Hidden`. Got mesh'
f' axis_types: {mesh.axis_types}')
@@ -439,7 +439,7 @@ class NamedSharding(jsharding.Sharding):
# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
if mesh._any_axis_auto:
if mesh._any_axis_hidden:
dim_shardings, used_axes = [], [] # type: ignore
for d in sdy_sharding.dimension_shardings:
# TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open?
@@ -448,7 +448,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
used_axes.extend(d.axes)
remaining_axes = set(mesh.axis_names) - set(used_axes)
replicated_axes = tuple(r for r in remaining_axes
if mesh._name_to_type[r] == mesh_lib.AxisTypes.User)
if mesh._name_to_type[r] == mesh_lib.AxisTypes.Visible)
return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
sdy_sharding.logical_device_ids, replicated_axes)
return sdy_sharding
@@ -1774,9 +1774,9 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
for s in flatten_spec(sharding.spec):
if sharding.mesh._name_to_type[s] in {
mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
mesh_lib.AxisTypes.Hidden, mesh_lib.AxisTypes.Collective}:
raise ValueError(
'PartitionSpec cannot contain axis names that are of type Auto or'
'PartitionSpec cannot contain axis names that are of type Hidden or'
f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
f' {s} or type: {sharding.mesh._name_to_type[s]}')
return sharding


@@ -1549,12 +1549,12 @@ def with_and_without_mesh(f):
))(with_mesh_from_kwargs(f))
def with_user_mesh(sizes, names, axis_types=None):
axis_types = ({mesh_lib.AxisTypes.User: names}
axis_types = ({mesh_lib.AxisTypes.Visible: names}
if axis_types is None else axis_types)
def decorator(fn):
def mesh_fn(*args, **kwargs):
mesh = create_mesh(sizes, names, axis_types=axis_types)
with mesh_lib.set_mesh(mesh):
with mesh_lib.use_mesh(mesh):
return fn(*args, **kwargs, mesh=mesh)
return mesh_fn
return decorator


@@ -27,7 +27,11 @@ from jax._src.partition_spec import (
PartitionSpec as PartitionSpec,
)
from jax._src.interpreters.pxla import Mesh as Mesh
from jax._src.mesh import AbstractMesh as AbstractMesh
from jax._src.mesh import (
AbstractMesh as AbstractMesh,
AxisTypes as AxisTypes,
use_mesh as use_mesh
)
_deprecations = {
# Finalized 2024-10-01; remove after 2025-01-01.


@@ -1260,7 +1260,6 @@ class ShardingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, msg):
jax.make_array_from_single_device_arrays(x.shape, s, [x, x])
def test_gspmd_sharding_hash_eq(self):
mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))
ns = NamedSharding(mesh, P('x', 'y', 'z'))
@@ -1299,6 +1298,21 @@ class ShardingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, 'Mesh axis names cannot be None.'):
jax.sharding.Mesh(jax.devices(), (None, 'x'))
def test_mesh_axis_types_mismatch(self):
with self.assertRaisesRegex(
ValueError,
'Number of axis names in axis_types should match the number of'
' axis_names'):
jtu.create_mesh((2, 1), ('x', 'y'),
axis_types={jax.sharding.AxisTypes.Hidden: 'x'})
with self.assertRaisesRegex(
ValueError,
'Number of axis names in axis_types should match the number of'
' axis_names in shape_tuple'):
jax.sharding.AbstractMesh((('x', 2), ('y', 1)),
axis_types={jax.sharding.AxisTypes.Hidden: 'x'})
@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyShardingTest(jtu.JaxTestCase):


@@ -51,8 +51,8 @@ from jax._src import sharding_impls
from jax._src.sharding_impls import (
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
SingleDeviceSharding, parse_flatten_op_sharding)
from jax._src.pjit import (pjit, sharding_cast, auto_mode, user_mode,
auto_mode_ctx, user_mode_ctx)
from jax._src.pjit import (pjit, sharding_cast, hidden_mode, visible_mode,
hidden_axes, visible_axes)
from jax._src import mesh as mesh_lib
from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
from jax._src.interpreters import pxla
@@ -5125,7 +5125,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jtu.with_user_mesh((1,), 'x')
def test_broadcasting_nary_error(self, mesh):
mesh2 = Mesh([jax.devices()[0]], 'y',
axis_types={mesh_lib.AxisTypes.User: 'y'})
axis_types={mesh_lib.AxisTypes.Visible: 'y'})
arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P()))
arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@@ -5688,7 +5688,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
self.check_wsc_in_lowered(f.lower(arr).as_text())
@jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
@jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Hidden: ('x', 'y')})
def test_only_auto(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
@@ -5708,7 +5708,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_auto_user(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@@ -5720,16 +5720,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
a = z @ x2
return a
with mesh_lib.set_mesh(mesh):
with mesh_lib.use_mesh(mesh):
out = f(arr, arr.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
lowered_text = f.lower(arr, arr.T).as_text()
self.assertNotIn('unspecified_dims', lowered_text)
mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.User: 'x',
mesh_lib.AxisTypes.Auto: 'y'})
with mesh_lib.set_mesh(mesh2):
axis_types={mesh_lib.AxisTypes.Visible: 'x',
mesh_lib.AxisTypes.Hidden: 'y'})
with mesh_lib.use_mesh(mesh2):
arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
out = f(arr, arr2)
@@ -5741,9 +5741,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertTrue(lowered_text.count("unspecified_dims") == 5)
mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.User: 'y',
mesh_lib.AxisTypes.Auto: 'x'})
with mesh_lib.set_mesh(mesh3):
axis_types={mesh_lib.AxisTypes.Visible: 'y',
mesh_lib.AxisTypes.Hidden: 'x'})
with mesh_lib.use_mesh(mesh3):
arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P('y', 'x')))
out = f(arr, arr2)
@@ -5779,7 +5779,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
with auto_mode_ctx(axes=('x', 'y')):
with hidden_axes(axes=('x', 'y')):
y = sharding_cast(y, P(None, None))
self.assertEqual(y.sharding.spec, P(None, None))
z = jnp.sin(y)
@@ -5798,7 +5798,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None)))
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
def test_full_auto_to_full_user(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
@@ -5807,7 +5807,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
with user_mode_ctx(axes=('x', 'y')):
with visible_axes(axes=('x', 'y')):
y = sharding_cast(y, P(None, 'y'))
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
@@ -5833,7 +5833,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
with auto_mode_ctx('x'):
with hidden_axes('x'):
y = sharding_cast(y, P(None, 'y'))
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
@@ -5860,7 +5860,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, y):
x = x * 2
with auto_mode_ctx('x'):
with hidden_axes('x'):
z = x @ y
return z
@@ -5871,9 +5871,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_sharding_cast_src_dst_mesh_mismatch(self):
np_inp = np.arange(16.).reshape(8, 2)
mesh = jtu.create_mesh((2, 1), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.User: ('x', 'y')})
axis_types={mesh_lib.AxisTypes.Visible: ('x', 'y')})
mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
axis_types={mesh_lib.AxisTypes.User: ('a', 'b')})
axis_types={mesh_lib.AxisTypes.Visible: ('a', 'b')})
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
f = lambda x: sharding_cast(x, NamedSharding(mesh2, P('a', 'b')))
@@ -5881,7 +5881,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
ValueError, "Mesh shape of the input.*does not match"):
f(arr)
with mesh_lib.set_mesh(mesh):
with mesh_lib.use_mesh(mesh):
with self.assertRaisesRegex(
ValueError, "Mesh shape of the input.*does not match"):
jax.jit(f)(arr)
@@ -5923,15 +5923,15 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Auto: 'x'})
auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Hidden: 'x'})
with set_abstract_mesh(auto_mesh):
x = sharding_cast(x, P(None, None))
return x
self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
out = f(arr)
self.assertArraysEqual(out, np_inp)
self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
@jtu.with_user_mesh((2,), 'x')
def test_inputs_different_context(self, mesh):
@@ -5939,18 +5939,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np_inp, s)
auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Auto: 'x'})
with mesh_lib.set_mesh(auto_mesh):
auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Hidden: 'x'})
with mesh_lib.use_mesh(auto_mesh):
arr2 = jnp.ones(8)
self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
@jax.jit
def f(x, y):
return x, y
out1, out2 = f(arr, arr2)
self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
@jtu.with_user_mesh((2,), 'x')
def test_output_different_context_error(self, mesh):
@@ -5958,7 +5958,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
auto_mesh = jax.make_mesh((2,), 'x',
axis_types={AxisTypes.Auto: 'x'}).abstract_mesh
axis_types={AxisTypes.Hidden: 'x'}).abstract_mesh
@jax.jit
def f(x, y):
@@ -5972,17 +5972,17 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def g(x, y):
with auto_mode_ctx('x'):
with hidden_axes('x'):
out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
return out
with self.assertRaisesRegex(
ValueError, "PartitionSpec cannot contain axis names.*Auto"):
ValueError, "PartitionSpec cannot contain axis names.*Hidden"):
g(arr1, arr2)
@jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'),
axis_types={AxisTypes.User: ('x', 'y'),
AxisTypes.Auto: 'z'})
axis_types={AxisTypes.Visible: ('x', 'y'),
AxisTypes.Hidden: 'z'})
def test_out_sharding_mix_axis_types(self, mesh):
np_inp = np.arange(16).reshape(4, 2, 2)
s = NamedSharding(mesh, P('x', None, None))
@@ -6011,7 +6011,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(auto_mode, axes='x', out_specs=P('x', None))
@partial(hidden_mode, axes='x', out_specs=P('x', None))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
@@ -6035,13 +6035,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None)))
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
def test_full_user_mode(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(user_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
@partial(visible_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
@@ -6064,14 +6064,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.User: 'x',
mesh_lib.AxisTypes.Auto: 'y'})
axis_types={mesh_lib.AxisTypes.Visible: 'x',
mesh_lib.AxisTypes.Hidden: 'y'})
def test_mix_to_full_user_mode(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(user_mode, axes='y', in_specs=P('x', 'y'))
@partial(visible_mode, axes='y', in_specs=P('x', 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
@@ -6091,13 +6091,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
@jtu.with_user_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
def test_full_auto_to_partial_user(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(user_mode, axes='y', in_specs=P(None, 'y'))
@partial(visible_mode, axes='y', in_specs=P(None, 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)