Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)

commit 49224d6cdb (parent bd22bfef71)

Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.

Replace `with set_mesh(mesh):` with the `with use_mesh(mesh):` context
manager. Also expose `AxisTypes` and `use_mesh` in the public API via
`jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
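A minimal sketch of the renamed surface after this change, assuming a runtime with at least two devices. `jax.sharding.AxisTypes` and `jax.sharding.use_mesh` are the names this commit exposes; the `axis_types` keyword follows the dict form used throughout the diff and is part of this in-flight work, not the stable API:

    import numpy as np
    import jax
    from jax.sharding import AxisTypes, Mesh, use_mesh

    # Hidden axes are compiler-managed; Visible axes show up in user specs.
    mesh = Mesh(np.array(jax.devices()[:2]), ('x',),
                axis_types={AxisTypes.Visible: 'x'})

    with use_mesh(mesh):  # replaces the old `with set_mesh(mesh):`
        y = jax.numpy.ones(8)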
@@ -729,9 +729,10 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
     # This attribute is part of the jax.Array API, but only defined on concrete arrays.
     # Raising a ConcretizationTypeError would make sense, but for backward compatibility
     # we raise an AttributeError so that hasattr() and getattr() work as expected.
-    raise AttributeError(self,
-                         f"The 'sharding' attribute is not available on {self._error_repr()}."
-                         f"{self._origin_msg()}")
+    raise AttributeError(
+        self,
+        f"The 'sharding' attribute is not available on {self._error_repr()}."
+        f"{self._origin_msg()}")
 
   @property
   def committed(self):
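The backward-compatibility comment above is the crux: `hasattr()` and `getattr()` only swallow `AttributeError`, so raising any other type would leak out of attribute probes. A standalone illustration with invented stand-in classes:

    class ConcretizationTypeError(Exception):  # stand-in for JAX's real error
        pass

    class TracerLike:
        @property
        def sharding(self):
            # AttributeError keeps hasattr()/getattr() usable on tracers.
            raise AttributeError("'sharding' is not available on a tracer")

        @property
        def strict_attr(self):
            # For contrast: any other exception type escapes hasattr().
            raise ConcretizationTypeError("not concrete")

    t = TracerLike()
    print(hasattr(t, "sharding"))        # False
    print(getattr(t, "sharding", None))  # None
    # hasattr(t, "strict_attr") would raise ConcretizationTypeError.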
@@ -1674,7 +1675,7 @@ def _invalid_shape_error(shape: Shape, context: str=""):
 
 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
-def modify_spec_for_auto(spec, mesh) -> P:
+def modify_spec_for_hidden(spec, mesh) -> P:
   new_spec = []  # type: ignore
   for s in spec:
     if s is None:
@@ -1682,13 +1683,13 @@ def modify_spec_for_auto(spec, mesh) -> P:
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
       new_spec.append(
-          None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
+          None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Hidden else s)
   return P(*new_spec)
 
 def _maybe_modify_sharding(sharding):
-  if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
+  if mesh_lib.AxisTypes.Hidden not in sharding.mesh.axis_types:
     return sharding
-  new_spec = modify_spec_for_auto(sharding.spec, sharding.mesh)
+  new_spec = modify_spec_for_hidden(sharding.spec, sharding.mesh)
   return sharding.with_spec(new_spec)
 
 
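A self-contained sketch of the filtering `modify_spec_for_hidden` performs, with a plain dict standing in for `mesh._name_to_type` and string tags standing in for the enum (the mesh plumbing is elided):

    from jax.sharding import PartitionSpec as P

    # Stand-in for mesh._name_to_type: 'x' is Hidden, 'y' is Visible.
    name_to_type = {'x': 'Hidden', 'y': 'Visible'}

    def modify_spec_for_hidden_sketch(spec, name_to_type):
        new_spec = []
        for s in spec:
            if s is None:
                new_spec.append(None)
                continue
            head = s[0] if isinstance(s, tuple) else s
            # Entries naming Hidden axes are erased; Visible entries survive.
            new_spec.append(None if name_to_type[head] == 'Hidden' else s)
        return P(*new_spec)

    print(modify_spec_for_hidden_sketch(P('x', 'y'), name_to_type))
    # PartitionSpec(None, 'y')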
@@ -2596,7 +2596,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   # Don't emit a wsc under full manual mode to avoid increasing HLO size.
   if aval.sharding.mesh._are_all_axes_collective:
     return op
-  if aval.sharding.mesh._are_all_axes_auto:
+  if aval.sharding.mesh._are_all_axes_hidden:
     return op
   # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
   # `return op` early and avoid bloating HLO size.
@@ -2609,7 +2609,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
            if sharding_proto is None else sharding_proto)
   unspecified_dims = None
-  if aval.sharding.mesh._any_axis_auto:
+  if aval.sharding.mesh._any_axis_hidden:
     # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes
     # as unspecified?
     unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
@@ -2172,7 +2172,7 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
                               for sp in a.sharding.spec])
-              if a.sharding.mesh._any_axis_auto else a.sharding.spec)
+              if a.sharding.mesh._any_axis_hidden else a.sharding.spec)
       out.append(NamedSharding(
           _abstract_to_concrete_mesh(a.sharding.mesh), spec))
     else:
@@ -1883,7 +1883,7 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
                           slice_sizes, unique_indices, indices_are_sorted,
                           mode, fill_value):
   # TODO(yashkatariya): Write a proper gather sharding rule.
-  if mesh_lib.get_abstract_mesh()._are_all_axes_auto:  # type: ignore
+  if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
     return None
   raise GatherShardingError(
       "Use `.at[...].get(out_specs=)` to provide output PartitionSpec for the"
@@ -49,7 +49,7 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
 
 def call_sharding_rule(rule, num_out, *avals, **kwargs):
   if config.sharding_in_types.value:
-    if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_auto:  # type: ignore
+    if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
      return None if num_out is None else [None] * num_out
     return rule(*avals, **kwargs)
   return None if num_out is None else [None] * num_out
@@ -103,8 +103,8 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
 
 
 class AxisTypes(enum.Enum):
-  Auto = enum.auto()
-  User = enum.auto()
+  Hidden = enum.auto()
+  Visible = enum.auto()
   Collective = enum.auto()
 
   def __repr__(self):
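Under the new names, each mesh axis is declared Hidden (the compiler decides its sharding), Visible (the user writes it in PartitionSpecs), or Collective. A sketch of the dict form used throughout this diff, against the internal module; it assumes at least two devices, and `axis_types` is part of this in-flight work rather than the stable API:

    import numpy as np
    import jax
    from jax._src import mesh as mesh_lib

    devices = np.array(jax.devices()[:2]).reshape(2, 1)
    m = mesh_lib.Mesh(devices, ('x', 'y'),
                      axis_types={mesh_lib.AxisTypes.Visible: 'x',
                                  mesh_lib.AxisTypes.Hidden: 'y'})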
@@ -198,9 +198,15 @@ class Mesh(contextlib.ContextDecorator):
                        f"devices.ndim == {devices.ndim} and "
                        f"len(axis_names) == {len(axis_names)}.")
 
-    axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
+    axis_types = ({AxisTypes.Hidden: axis_names} if axis_types is None else
                   axis_types)
     axis_types_tuple = tuple(axis_types.items())
+    if len(axis_names_to_types(axis_types).keys()) != len(axis_names):
+      raise ValueError(
+          "Number of axis names in axis_types should match the number of"
+          f" axis_names. Got axis_names={axis_names} and"
+          f" axis_types={axis_types}")
+
     key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
     val = _mesh_object_dict.get(key, None)
     if val is not None:
@@ -356,16 +362,16 @@ class Mesh(contextlib.ContextDecorator):
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _are_all_axes_auto(self) -> bool:
-    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _are_all_axes_hidden(self) -> bool:
+    return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
   @functools.cached_property
   def _any_axis_collective(self) -> bool:
     return any(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _any_axis_auto(self) -> bool:
-    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _any_axis_hidden(self) -> bool:
+    return any(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
 
 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -397,9 +403,14 @@ class AbstractMesh:
       self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
     else:
       self._axis_names, self._axis_sizes = (), ()
-    self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
-                       else axis_types)
+    self.axis_types = ({AxisTypes.Hidden: self._axis_names}
+                       if axis_types is None else axis_types)
     self._axis_types_tuple = tuple(self.axis_types.items())
+    if len(self._name_to_type.keys()) != len(self._axis_names):
+      raise ValueError(
+          "Number of axis names in axis_types should match the number of"
+          f" axis_names in shape_tuple. Got axis_names={self._axis_names} and"
+          f" axis_types={self.axis_types}")
 
   def __hash__(self):
     return hash((self.shape_tuple, self._axis_types_tuple))
@@ -461,16 +472,16 @@ class AbstractMesh:
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _are_all_axes_auto(self) -> bool:
-    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _are_all_axes_hidden(self) -> bool:
+    return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
   @functools.cached_property
   def _any_axis_collective(self) -> bool:
     return any(t == AxisTypes.Collective for t in self.axis_types.keys())
 
   @functools.cached_property
-  def _any_axis_auto(self) -> bool:
-    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+  def _any_axis_hidden(self) -> bool:
+    return any(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
   @property
   def devices(self):
@@ -535,7 +546,7 @@ def get_concrete_mesh():
 
 
 @contextlib.contextmanager
-def set_mesh(mesh: Mesh):
+def use_mesh(mesh: Mesh):
   with (set_abstract_mesh(mesh.abstract_mesh),
         jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
     yield
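Usage of the renamed context manager, which also turns on sharding-in-types and sets both the abstract and concrete mesh for the dynamic extent of the block. A sketch over whatever devices are available (per the commit message, the same entry point is re-exported as `jax.sharding.use_mesh`):

    import numpy as np
    import jax
    from jax._src import mesh as mesh_lib

    mesh = mesh_lib.Mesh(np.array(jax.devices()), ('x',))
    with mesh_lib.use_mesh(mesh):
        out = jax.jit(lambda a: a * 2)(np.arange(8.))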
@@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.numpy import array_api_metadata
 from jax._src.numpy import lax_numpy
 from jax._src import mesh as mesh_lib
-from jax._src.pjit import auto_mode, PartitionSpec
+from jax._src.pjit import hidden_mode, PartitionSpec
 from jax._src.numpy import reductions
 from jax._src.numpy import ufuncs
 from jax._src.ops import scatter
@@ -781,8 +781,8 @@ class _IndexUpdateRef:
                            fill_value=fill_value)
     if out_spec is not None:
       assert isinstance(out_spec, PartitionSpec)
-      take = auto_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names,  # type: ignore
-                       out_specs=out_spec)
+      take = hidden_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names,  # type: ignore
+                         out_specs=out_spec)
     return take(self.array, self.index)
 
   def set(self, values, *, indices_are_sorted=False, unique_indices=False,
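This is the plumbing behind the gather error shown earlier: when an indexed `get` carries an output spec, the take is wrapped in `hidden_mode` so every mesh axis is compiler-managed inside, and the result is cast to the requested spec at the boundary. A hedged sketch of that wrapping pattern with a generic function in place of `take` (internal, in-flight API; note the diff spells the keyword `out_spec` in the method body but `out_specs=` in the error message):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax._src import mesh as mesh_lib
    from jax._src.pjit import hidden_mode
    from jax.sharding import PartitionSpec as P

    mesh = mesh_lib.Mesh(np.array(jax.devices()[:2]), ('x',),
                         axis_types={mesh_lib.AxisTypes.Visible: 'x'})
    with mesh_lib.use_mesh(mesh):
        # Same wrapping the `out_spec` branch above applies to `take`:
        f_hidden = hidden_mode(jnp.sin, axes='x', out_specs=P('x'))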
@@ -2754,11 +2754,11 @@ def _get_new_mesh(axes: str | tuple[str, ...], axis_type: mesh_lib.AxisTypes):
   new_mesh = cur_mesh.update_axis_types({axis_type: axes})  # type: ignore
   return new_mesh
 
-def auto_mode(fun, *, axes: str | tuple[str, ...], out_specs):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
+def hidden_mode(fun, *, axes: str | tuple[str, ...], out_specs):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
   def decorator(*args, **kwargs):
     with mesh_lib.set_abstract_mesh(new_mesh):
-      in_specs = tree_map(lambda a: core.modify_spec_for_auto(
+      in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
           a.sharding.spec, new_mesh), args)
       args = sharding_cast(args, in_specs)
       out = fun(*args, **kwargs)
@@ -2767,26 +2767,26 @@ def auto_mode(fun, *, axes: str | tuple[str, ...], out_specs):
 
 
 @contextlib.contextmanager
-def auto_mode_ctx(axes: str | tuple[str, ...]):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
+def hidden_axes(axes: str | tuple[str, ...]):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
   with mesh_lib.set_abstract_mesh(new_mesh):
     yield
 
 
-def user_mode(fun, *, axes: str | tuple[str, ...], in_specs):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.User)
+def visible_mode(fun, *, axes: str | tuple[str, ...], in_specs):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
   def decorator(*args, **kwargs):
     with mesh_lib.set_abstract_mesh(new_mesh):
       args = sharding_cast(args, in_specs)
       out = fun(*args, **kwargs)
-      out_specs = tree_map(lambda o: core.modify_spec_for_auto(
+      out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
           o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
       return sharding_cast(out, out_specs)
   return decorator
 
 @contextlib.contextmanager
-def user_mode_ctx(axes: str | tuple[str, ...]):
-  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.User)
+def visible_axes(axes: str | tuple[str, ...]):
+  new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
   with mesh_lib.set_abstract_mesh(new_mesh):
     yield
 
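The four entry points above pair up: `hidden_mode`/`visible_mode` wrap a whole function and cast at its boundary (via `out_specs`/`in_specs`), while `hidden_axes`/`visible_axes` switch the abstract mesh for a block of code. A sketch mirroring how the tests below use them (internal, in-flight API):

    from functools import partial
    import jax.numpy as jnp
    from jax._src.pjit import hidden_mode, hidden_axes, sharding_cast
    from jax.sharding import PartitionSpec as P

    # Function form: 'x' is Hidden inside h; the result is cast back to
    # out_specs at the boundary.
    @partial(hidden_mode, axes='x', out_specs=P('x', None))
    def h(y):
        return jnp.sin(y)

    # Context-manager form: 'x' is Hidden for a few ops only.
    def f(y):
        with hidden_axes('x'):
            y = sharding_cast(y, P(None, None))
            y = jnp.sin(y)
        return y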
@@ -78,10 +78,10 @@ def _check_axis_type_consistency(mesh, parsed_pspec):
             'AxisTypes should be the same in a tuple subset of PartitionSpec:'
             f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
             f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
-  if mesh_lib.AxisTypes.Auto not in mesh.axis_types and None in parsed_pspec:
+  if mesh_lib.AxisTypes.Hidden not in mesh.axis_types and None in parsed_pspec:
     raise ValueError(
         f'PartitionSpec {parsed_pspec.get_partition_spec()} cannot contain'
-        ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
+        ' `P.UNCONSTRAINED` when no mesh axis_types are `Hidden`. Got mesh'
         f' axis_types: {mesh.axis_types}')
 
 
@@ -439,7 +439,7 @@ class NamedSharding(jsharding.Sharding):
 # TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
 # parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
 def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
-  if mesh._any_axis_auto:
+  if mesh._any_axis_hidden:
     dim_shardings, used_axes = [], []  # type: ignore
     for d in sdy_sharding.dimension_shardings:
       # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open?
@@ -448,7 +448,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
       used_axes.extend(d.axes)
     remaining_axes = set(mesh.axis_names) - set(used_axes)
     replicated_axes = tuple(r for r in remaining_axes
-                            if mesh._name_to_type[r] == mesh_lib.AxisTypes.User)
+                            if mesh._name_to_type[r] == mesh_lib.AxisTypes.Visible)
     return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
                             sdy_sharding.logical_device_ids, replicated_axes)
   return sdy_sharding
@@ -1774,9 +1774,9 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
 
   for s in flatten_spec(sharding.spec):
     if sharding.mesh._name_to_type[s] in {
-        mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
+        mesh_lib.AxisTypes.Hidden, mesh_lib.AxisTypes.Collective}:
       raise ValueError(
-          'PartitionSpec cannot contain axis names that are of type Auto or'
+          'PartitionSpec cannot contain axis names that are of type Hidden or'
           f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
           f' {s} or type: {sharding.mesh._name_to_type[s]}')
   return sharding
@@ -1549,12 +1549,12 @@ def with_and_without_mesh(f):
   ))(with_mesh_from_kwargs(f))
 
 def with_user_mesh(sizes, names, axis_types=None):
-  axis_types = ({mesh_lib.AxisTypes.User: names}
+  axis_types = ({mesh_lib.AxisTypes.Visible: names}
                 if axis_types is None else axis_types)
   def decorator(fn):
     def mesh_fn(*args, **kwargs):
       mesh = create_mesh(sizes, names, axis_types=axis_types)
-      with mesh_lib.set_mesh(mesh):
+      with mesh_lib.use_mesh(mesh):
        return fn(*args, **kwargs, mesh=mesh)
     return mesh_fn
   return decorator
@@ -27,7 +27,11 @@ from jax._src.partition_spec import (
     PartitionSpec as PartitionSpec,
 )
 from jax._src.interpreters.pxla import Mesh as Mesh
-from jax._src.mesh import AbstractMesh as AbstractMesh
+from jax._src.mesh import (
+    AbstractMesh as AbstractMesh,
+    AxisTypes as AxisTypes,
+    use_mesh as use_mesh
+)
 
 _deprecations = {
     # Finalized 2024-10-01; remove after 2025-01-01.
@@ -1260,7 +1260,6 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, msg):
       jax.make_array_from_single_device_arrays(x.shape, s, [x, x])
 
-
   def test_gspmd_sharding_hash_eq(self):
     mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z'))
     ns = NamedSharding(mesh, P('x', 'y', 'z'))
@@ -1299,6 +1298,21 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, 'Mesh axis names cannot be None.'):
       jax.sharding.Mesh(jax.devices(), (None, 'x'))
 
+  def test_mesh_axis_types_mismatch(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'Number of axis names in axis_types should match the number of'
+        ' axis_names'):
+      jtu.create_mesh((2, 1), ('x', 'y'),
+                      axis_types={jax.sharding.AxisTypes.Hidden: 'x'})
+
+    with self.assertRaisesRegex(
+        ValueError,
+        'Number of axis names in axis_types should match the number of'
+        ' axis_names in shape_tuple'):
+      jax.sharding.AbstractMesh((('x', 2), ('y', 1)),
+                                axis_types={jax.sharding.AxisTypes.Hidden: 'x'})
+
 
 @jtu.with_config(jax_use_shardy_partitioner=True)
 class ShardyShardingTest(jtu.JaxTestCase):
@@ -51,8 +51,8 @@ from jax._src import sharding_impls
 from jax._src.sharding_impls import (
     AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
     SingleDeviceSharding, parse_flatten_op_sharding)
-from jax._src.pjit import (pjit, sharding_cast, auto_mode, user_mode,
-                           auto_mode_ctx, user_mode_ctx)
+from jax._src.pjit import (pjit, sharding_cast, hidden_mode, visible_mode,
+                           hidden_axes, visible_axes)
 from jax._src import mesh as mesh_lib
 from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
 from jax._src.interpreters import pxla
@@ -5125,7 +5125,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
   @jtu.with_user_mesh((1,), 'x')
   def test_broadcasting_nary_error(self, mesh):
     mesh2 = Mesh([jax.devices()[0]], 'y',
-                 axis_types={mesh_lib.AxisTypes.User: 'y'})
+                 axis_types={mesh_lib.AxisTypes.Visible: 'y'})
 
     arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P()))
     arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@@ -5688,7 +5688,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
     self.check_wsc_in_lowered(f.lower(arr).as_text())
 
-  @jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
+  @jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_only_auto(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
@@ -5708,7 +5708,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
   def test_auto_user(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'),
-                           axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                           axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -5720,16 +5720,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       a = z @ x2
       return a
 
-    with mesh_lib.set_mesh(mesh):
+    with mesh_lib.use_mesh(mesh):
       out = f(arr, arr.T)
       self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
       lowered_text = f.lower(arr, arr.T).as_text()
       self.assertNotIn('unspecified_dims', lowered_text)
 
     mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
-                            axis_types={mesh_lib.AxisTypes.User: 'x',
-                                        mesh_lib.AxisTypes.Auto: 'y'})
-    with mesh_lib.set_mesh(mesh2):
+                            axis_types={mesh_lib.AxisTypes.Visible: 'x',
+                                        mesh_lib.AxisTypes.Hidden: 'y'})
+    with mesh_lib.use_mesh(mesh2):
       arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
       arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
       out = f(arr, arr2)
@@ -5741,9 +5741,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       self.assertTrue(lowered_text.count("unspecified_dims") == 5)
 
     mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
-                            axis_types={mesh_lib.AxisTypes.User: 'y',
-                                        mesh_lib.AxisTypes.Auto: 'x'})
-    with mesh_lib.set_mesh(mesh3):
+                            axis_types={mesh_lib.AxisTypes.Visible: 'y',
+                                        mesh_lib.AxisTypes.Hidden: 'x'})
+    with mesh_lib.use_mesh(mesh3):
       arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
       arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P('y', 'x')))
       out = f(arr, arr2)
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
with auto_mode_ctx(axes=('x', 'y')):
|
||||
with hidden_axes(axes=('x', 'y')):
|
||||
y = sharding_cast(y, P(None, None))
|
||||
self.assertEqual(y.sharding.spec, P(None, None))
|
||||
z = jnp.sin(y)
|
||||
@@ -5798,7 +5798,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None)))
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                      axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_full_auto_to_full_user(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
@@ -5807,7 +5807,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x):
       y = x * 2
-      with user_mode_ctx(axes=('x', 'y')):
+      with visible_axes(axes=('x', 'y')):
         y = sharding_cast(y, P(None, 'y'))
         self.assertEqual(y.sharding.spec, P(None, 'y'))
         z = jnp.sin(y)
@@ -5833,7 +5833,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x):
       y = x * 2
-      with auto_mode_ctx('x'):
+      with hidden_axes('x'):
         y = sharding_cast(y, P(None, 'y'))
         self.assertEqual(y.sharding.spec, P(None, 'y'))
         z = jnp.sin(y)
@@ -5860,7 +5860,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       x = x * 2
-      with auto_mode_ctx('x'):
+      with hidden_axes('x'):
         z = x @ y
       return z
 
@@ -5871,9 +5871,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
   def test_sharding_cast_src_dst_mesh_mismatch(self):
     np_inp = np.arange(16.).reshape(8, 2)
     mesh = jtu.create_mesh((2, 1), ('x', 'y'),
-                           axis_types={mesh_lib.AxisTypes.User: ('x', 'y')})
+                           axis_types={mesh_lib.AxisTypes.Visible: ('x', 'y')})
     mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
-                            axis_types={mesh_lib.AxisTypes.User: ('a', 'b')})
+                            axis_types={mesh_lib.AxisTypes.Visible: ('a', 'b')})
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
     f = lambda x: sharding_cast(x, NamedSharding(mesh2, P('a', 'b')))
@@ -5881,7 +5881,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         ValueError, "Mesh shape of the input.*does not match"):
       f(arr)
 
-    with mesh_lib.set_mesh(mesh):
+    with mesh_lib.use_mesh(mesh):
       with self.assertRaisesRegex(
           ValueError, "Mesh shape of the input.*does not match"):
         jax.jit(f)(arr)
@@ -5923,15 +5923,15 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     @jax.jit
     def f(x):
-      auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Auto: 'x'})
+      auto_mesh = get_abstract_mesh().update_axis_types({AxisTypes.Hidden: 'x'})
       with set_abstract_mesh(auto_mesh):
         x = sharding_cast(x, P(None, None))
       return x
 
-    self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
+    self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
     out = f(arr)
     self.assertArraysEqual(out, np_inp)
-    self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
+    self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
 
   @jtu.with_user_mesh((2,), 'x')
   def test_inputs_different_context(self, mesh):
@@ -5939,18 +5939,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     s = NamedSharding(mesh, P('x'))
     arr = jax.device_put(np_inp, s)
 
-    auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Auto: 'x'})
-    with mesh_lib.set_mesh(auto_mesh):
+    auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Hidden: 'x'})
+    with mesh_lib.use_mesh(auto_mesh):
       arr2 = jnp.ones(8)
-      self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
+      self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
 
     @jax.jit
     def f(x, y):
       return x, y
 
     out1, out2 = f(arr, arr2)
-    self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
-    self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
+    self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
+    self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
 
   @jtu.with_user_mesh((2,), 'x')
   def test_output_different_context_error(self, mesh):
@@ -5958,7 +5958,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
     auto_mesh = jax.make_mesh((2,), 'x',
-                              axis_types={AxisTypes.Auto: 'x'}).abstract_mesh
+                              axis_types={AxisTypes.Hidden: 'x'}).abstract_mesh
 
     @jax.jit
     def f(x, y):
@@ -5972,17 +5972,17 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     @jax.jit
     def g(x, y):
-      with auto_mode_ctx('x'):
+      with hidden_axes('x'):
         out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
       return out
 
     with self.assertRaisesRegex(
-        ValueError, "PartitionSpec cannot contain axis names.*Auto"):
+        ValueError, "PartitionSpec cannot contain axis names.*Hidden"):
       g(arr1, arr2)
 
   @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'),
-                      axis_types={AxisTypes.User: ('x', 'y'),
-                                  AxisTypes.Auto: 'z'})
+                      axis_types={AxisTypes.Visible: ('x', 'y'),
+                                  AxisTypes.Hidden: 'z'})
   def test_out_sharding_mix_axis_types(self, mesh):
     np_inp = np.arange(16).reshape(4, 2, 2)
     s = NamedSharding(mesh, P('x', None, None))
@@ -6011,7 +6011,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(auto_mode, axes='x', out_specs=P('x', None))
+    @partial(hidden_mode, axes='x', out_specs=P('x', None))
     def h(y):
       self.assertEqual(y.sharding.spec, P(None, 'y'))
       z = jnp.sin(y)
@@ -6035,13 +6035,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None)))
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                      axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_full_user_mode(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(user_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
+    @partial(visible_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
     def h(y):
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       z = jnp.sin(y)
@@ -6064,14 +6064,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     core.jaxpr_as_fun(jaxpr)(arr)  # doesn't crash
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.User: 'x',
-                                  mesh_lib.AxisTypes.Auto: 'y'})
+                      axis_types={mesh_lib.AxisTypes.Visible: 'x',
+                                  mesh_lib.AxisTypes.Hidden: 'y'})
   def test_mix_to_full_user_mode(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(user_mode, axes='y', in_specs=P('x', 'y'))
+    @partial(visible_mode, axes='y', in_specs=P('x', 'y'))
     def h(y):
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       z = jnp.sin(y)
@@ -6091,13 +6091,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'),
-                      axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+                      axis_types={mesh_lib.AxisTypes.Hidden: ('x', 'y')})
   def test_full_auto_to_partial_user(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
 
-    @partial(user_mode, axes='y', in_specs=P(None, 'y'))
+    @partial(visible_mode, axes='y', in_specs=P(None, 'y'))
     def h(y):
       self.assertEqual(y.sharding.spec, P(None, 'y'))
       z = jnp.sin(y)