Mirror of https://github.com/ROCm/jax.git
[sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.

Also support `Auto` mode, either fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in the `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would let XLA shard them even further, which is not what we want if some of the dims are user-sharded.

PiperOrigin-RevId: 704911253
Parent: e88b578356
Commit: b5e4fd161d
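For orientation, a minimal sketch (not part of the diff) of the mixed-mode setup this change enables; the 4-device reshape and variable names are assumptions:

```python
# Sketch only: a 2x2 mesh whose 'x' axis is User-sharded and whose 'y'
# axis is Auto (left to XLA). Assumes at least 4 devices are available.
import jax
import numpy as np
from jax._src import mesh as mesh_lib
from jax.sharding import NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = mesh_lib.Mesh(devices, ('x', 'y'),
                     axis_types={mesh_lib.AxisTypes.User: 'x',
                                 mesh_lib.AxisTypes.Auto: 'y'})

# Under this mesh, an aval built with spec P('x', 'y') has the Auto axis
# 'y' rewritten to P.UNCONSTRAINED, and that dim is lowered as an
# unspecified dim that XLA may shard further.
s = NamedSharding(mesh, P('x', 'y'))
```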
@@ -455,6 +455,7 @@ pytype_strict_library(
         ":dtypes",
         ":effects",
         ":mesh",
+        ":partition_spec",
         ":pretty_printer",
         ":source_info_util",
         ":traceback_util",
@@ -558,6 +559,7 @@ pytype_strict_library(
         ":layout",
         ":op_shardings",
         ":partial_eval",
+        ":partition_spec",
         ":path",
         ":pickle_util",
         ":sharding",
@@ -39,6 +39,7 @@ from jax._src import config
 from jax._src import effects
 from jax._src import compute_on
 from jax._src import mesh as mesh_lib
+from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
     TracerIntegerConversionError, UnexpectedTracerError)
@@ -1599,13 +1600,30 @@ def _invalid_shape_error(shape: Shape, context: str=""):

   return TypeError(msg)


+# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
+# Collective too.
+def _maybe_modify_sharding(sharding):
+  if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
+    return sharding
+
+  new_spec = []
+  for s in sharding.spec:
+    if s is None or isinstance(s, UnconstrainedSingleton):
+      new_spec.append(s)
+    else:
+      temp_s = s[0] if isinstance(s, tuple) else s
+      new_spec.append(
+          P.UNCONSTRAINED
+          if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
+  return sharding.with_spec(new_spec)
+
+
 def get_sharding(sharding, ndim):
-  from jax._src.sharding_impls import NamedSharding, PartitionSpec as P  # type: ignore
+  from jax._src.sharding_impls import NamedSharding  # type: ignore

   if sharding is not None:
     assert len(sharding.spec) == ndim
-    return sharding
+    return _maybe_modify_sharding(sharding)

   context_mesh = mesh_lib.get_abstract_mesh()
   # TODO(yashkatariya): Error out and ask users to set the context mesh in their
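To make the substitution concrete, here is a small self-contained sketch of the same logic, with a plain dict standing in for `mesh._name_to_type` (all names here are stand-ins, not the real internals):

```python
from jax.sharding import PartitionSpec as P

# Stand-in for sharding.mesh._name_to_type on a mesh with User 'x', Auto 'y'.
name_to_type = {'x': 'User', 'y': 'Auto'}

def maybe_modify_spec(spec):
  new_spec = []
  for s in spec:
    if s is None:
      new_spec.append(s)  # unsharded dims pass through untouched
    else:
      first = s[0] if isinstance(s, tuple) else s
      # Auto axes are overridden with UNCONSTRAINED; User axes are kept.
      new_spec.append(P.UNCONSTRAINED
                      if name_to_type[first] == 'Auto' else s)
  return P(*new_spec)

print(maybe_modify_spec(P('x', 'y')))   # PartitionSpec('x', UNCONSTRAINED)
print(maybe_modify_spec(P(None, 'x')))  # PartitionSpec(None, 'x')
```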
@@ -1675,9 +1693,7 @@ class ShapedArray(UnshapedArray):
     dt_str = dt_str.replace('void', 'float0')
     if hasattr(self, 'sharding') and self.sharding is not None:
       shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
-      axis_types = self.sharding.mesh.axis_types
-      axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
-      return f'{dt_str}[{shapestr}]{axt}'
+      return f'{dt_str}[{shapestr}]'
     else:
       shapestr = ','.join(map(str, self.shape))
       return f'{dt_str}[{shapestr}]'
@@ -1689,26 +1705,13 @@
       raise TypeError("len() of unsized object") from err  # same as numpy error


-def _get_axis_type_str(axis_types):
-  from jax._src.mesh import AxisTypes  # type: ignore
-
-  out = []
-  for t, axes in axis_types.items():
-    a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
-    if t == AxisTypes.Collective:
-      out.append(f"C:{a}")
-    elif t == AxisTypes.User:
-      out.append(f"U:{a}")
-    else:
-      assert t == AxisTypes.Auto
-      out.append(f"A:{a}")
-  return f"{{{', '.join(out)}}}"
-
 def _get_shape_sharding_str(shape, spec):
   out = []
   for s1, s2 in zip(shape, spec):
     if s2 is None:
       out.append(f"{s1}")
+    elif isinstance(s2, UnconstrainedSingleton):
+      out.append(f"{s1}")
     elif isinstance(s2, tuple):
       ss = ','.join(s for s in s2)
       out.append(f"{s1}@({ss})")
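The `dim@axis` notation produced by `_get_shape_sharding_str` is exercised by `test_aval_repr` near the end of this diff; a hedged re-implementation of the formatting rules, for illustration only (the trailing single-axis branch is assumed from the surrounding code):

```python
from jax.sharding import PartitionSpec as P

def shape_sharding_str(shape, spec):
  out = []
  for size, part in zip(shape, spec):
    if part is None or part is P.UNCONSTRAINED:
      out.append(f"{size}")                      # unsharded / unconstrained dim
    elif isinstance(part, tuple):
      out.append(f"{size}@({','.join(part)})")   # dim sharded over several axes
    else:
      out.append(f"{size}@{part}")               # dim sharded over one axis
  return ','.join(out)

print(shape_sharding_str((128, 64), P('model', 'data')))  # 128@model,64@data
```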
@@ -50,6 +50,7 @@ from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import AUTO
+from jax._src.partition_spec import UnconstrainedSingleton
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -2524,12 +2525,19 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   # Don't emit a wsc under full manual mode to avoid increasing HLO size.
   if aval.sharding.mesh._are_all_axes_collective:
     return op
+  if aval.sharding.mesh._are_all_axes_auto:
+    return op
+  # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
+  # `return op` early and avoid bloating HLO size.
   proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
            if sharding_proto is None else sharding_proto)
-  # TODO(yashkatariya): Enable this
-  # unspecified_dims = (set(range(aval.ndim))
-  #                     if aval.sharding.mesh._any_axis_collective else None)
-  return wrap_with_sharding_op(ctx, op, aval, proto)
+  unspecified_dims = None
+  if aval.sharding.mesh._any_axis_collective:
+    unspecified_dims = set(range(aval.ndim))
+  elif aval.sharding.mesh._any_axis_auto:
+    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
+                        if isinstance(s, UnconstrainedSingleton)}
+  return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


 def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
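A toy sketch of how the lowering picks `unspecified_dims` (the real signature differs; this just shows the set computation and why not all dims are left unspecified):

```python
from jax.sharding import PartitionSpec as P

def auto_unspecified_dims(spec):
  # Only dims whose entry was rewritten to UNCONSTRAINED are left open for
  # XLA; user-sharded dims stay pinned, hence not "all dims unspecified".
  return {i for i, s in enumerate(spec) if s is P.UNCONSTRAINED}

print(auto_unspecified_dims(P('x', P.UNCONSTRAINED, None)))  # {1}
```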
@@ -63,7 +63,7 @@ from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.partition_spec import PartitionSpec
+from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
@@ -2123,11 +2123,13 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
   @lru_cache(maxsize=128)
   def _abstract_to_concrete_mesh(abstract_mesh):
     return mesh_lib.Mesh(
-        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
+        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
+        axis_types=abstract_mesh.axis_types)

   out = []
   for s, a in zip(shardings, avals):
-    if isinstance(s, UnspecifiedValue) and a.sharding is not None:
+    if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
+        all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
                                a.sharding.spec))
     else:
@@ -124,6 +124,7 @@ def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:

 _mesh_object_dict = {}  # type: ignore

+MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]

 class Mesh(contextlib.ContextDecorator):
   """Declare the hardware resources available in the scope of this manager.
@@ -178,11 +179,11 @@ class Mesh(contextlib.ContextDecorator):

   devices: np.ndarray
   axis_names: tuple[MeshAxisName, ...]
-  axis_types: dict[AxisTypes, str | tuple[str, ...]] | None
+  axis_types: MeshAxisType | None

   def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
-              axis_names: str | Sequence[MeshAxisName],
-              axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
+              axis_names: str | Sequence[MeshAxisName], *,
+              axis_types: MeshAxisType | None = None):
     if not isinstance(devices, np.ndarray):
       devices = np.array(devices)
     if isinstance(axis_names, str):
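Note that `axis_types` becomes keyword-only here; a hedged sketch of the call sites (`devices` as in the first sketch):

```python
# OK after this change: keyword argument.
mesh = mesh_lib.Mesh(devices, ('x', 'y'),
                     axis_types={mesh_lib.AxisTypes.User: ('x', 'y')})

# TypeError after this change: axis_types passed positionally.
# mesh_lib.Mesh(devices, ('x', 'y'), {mesh_lib.AxisTypes.User: ('x', 'y')})
```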
@@ -216,7 +217,8 @@ class Mesh(contextlib.ContextDecorator):
     return self

   def __reduce__(self):
-    return (type(self), (self.devices, self.axis_names, self.axis_types))
+    return (type(self), (self.devices, self.axis_names),
+            {'axis_types': self.axis_types})

   def __eq__(self, other):
     if not isinstance(other, Mesh):
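With the constructor keyword-only, `__reduce__` carries `axis_types` in the pickle state rather than the positional args; a quick hedged round-trip check (`mesh` from the sketch above):

```python
import pickle

restored = pickle.loads(pickle.dumps(mesh))
assert restored.axis_types == mesh.axis_types
```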
@@ -348,7 +350,7 @@ class Mesh(contextlib.ContextDecorator):

   @functools.cached_property
   def abstract_mesh(self):
-    return AbstractMesh(self.shape_tuple, self.axis_types)
+    return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)


 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -373,8 +375,8 @@ class AbstractMesh:
   details.
   """

-  def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
-               axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
+  def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
+               axis_types: MeshAxisType | None = None):
     self.shape_tuple = shape_tuple
     self.axis_types = axis_types
     if self.shape_tuple:
@@ -434,6 +436,24 @@ class AbstractMesh:
       return False
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())

+  @functools.cached_property
+  def _are_all_axes_auto(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_collective(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return any(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_auto(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+
   @property
   def devices(self):
     _raise_value_error("devices")
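How the predicates come out on the mixed mesh from the first sketch (values follow directly from the definitions above):

```python
am = mesh.abstract_mesh  # User 'x', Auto 'y'
assert not am._are_all_axes_auto    # 'x' is User, so not all Auto
assert am._any_axis_auto            # 'y' is Auto
assert not am._any_axis_collective  # no Collective axes in this mesh
```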
@@ -474,6 +494,8 @@ def _raise_value_error(name):

 @contextlib.contextmanager
 def set_abstract_mesh(mesh: AbstractMesh):
+  if mesh is not None and mesh.axis_types is None:
+    raise RuntimeError('Please set the AxisTypes of Mesh.')
   prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
   try:
     yield
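This is the enforcement named in the commit title: a mesh entering `set_mesh`/`set_abstract_mesh` must carry axis types. A hedged sketch (`devices` as before):

```python
plain = mesh_lib.Mesh(devices, ('x', 'y'))  # axis_types=None
with mesh_lib.set_mesh(plain):              # RuntimeError:
  ...                                       # 'Please set the AxisTypes of Mesh.'
```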
@@ -14,7 +14,7 @@

 from __future__ import annotations

-class _UnconstrainedPartitionSingleton:
+class UnconstrainedSingleton:

   def __repr__(self):
     return "UNCONSTRAINED"
@@ -23,7 +23,7 @@ class _UnconstrainedPartitionSingleton:
 # Unconstrained sentinel value for PartitionSpec, representing a dimension for
 # which the user wants XLA to assign the best partitioning.
 # TODO(yashkatariya): May rename to AUTO.
-_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
+_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()


 class PartitionSpec(tuple):
@@ -67,6 +67,19 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
           f"is also found in manual_axes: {_manual_axes}.") from None


+@util.cache(max_size=128, trace_context_in_key=False)
+def _check_axis_type_consistency(mesh, parsed_pspec):
+  if mesh.axis_types is None:
+    return
+  for p in parsed_pspec:
+    if p is not None:
+      if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
+        raise ValueError(
+            'AxisTypes should be the same in a tuple subset of PartitionSpec:'
+            f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
+            f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
+
+
 def hashed_index(x) -> int:
   # This works for both `pjit` indices and `pmap` indices (which might
   # have an integer instead of a slice).
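The new check mirrors what `test_auto_user` at the bottom of this diff asserts: a tuple subset of a PartitionSpec cannot mix axis types. Hedged example on the mixed User/Auto mesh from the first sketch:

```python
# 'x' is User and 'y' is Auto, so sharding one dim over both is rejected:
try:
  NamedSharding(mesh, P(('x', 'y')))
except ValueError as e:
  print(e)  # AxisTypes should be the same in a tuple subset of PartitionSpec: ...
```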
@@ -1084,6 +1097,7 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
       PartitionSpec() if spec is None else spec,
       "NamedSharding spec", allow_unconstrained_dims=True)
   _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
+  _check_axis_type_consistency(mesh, parsed_pspec)
   return parsed_pspec
@@ -1673,7 +1687,8 @@ def _gspmd_to_named_sharding_via_mesh(


 def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
-              *, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh:
+              *, devices: Sequence[xc.Device] | None = None,
+              axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
   """Creates an efficient mesh with the shape and axis names specified.

   This function attempts to automatically compute a good mapping from a set of
@@ -1735,4 +1750,4 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
   mesh_devices = mesh_utils.create_device_mesh(
       new_axis_shapes, devices,
       allow_split_physical_axes=allow_split_physical_axes)
-  return mesh_lib.Mesh(mesh_devices, axis_names)
+  return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
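Usage sketch of the new public parameter (assumes 4 devices; `mesh_lib` imported as in the first sketch):

```python
mesh = jax.make_mesh((2, 2), ('x', 'y'),
                     axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
```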
@@ -1443,26 +1443,28 @@ def with_and_without_mesh(f):
       ('Mesh', (('x', 2),), (('i', 'x'),))
   ))(with_mesh_from_kwargs(f))

-def with_user_mesh(sizes, names):
+def with_user_mesh(sizes, names, axis_types=None):
+  axis_types = ({mesh_lib.AxisTypes.User: names}
+                if axis_types is None else axis_types)
   def decorator(fn):
     def mesh_fn(*args, **kwargs):
-      mesh = create_mesh(sizes, names)
+      mesh = create_mesh(sizes, names, axis_types=axis_types)
       with mesh_lib.set_mesh(mesh):
         return fn(*args, **kwargs, mesh=mesh)
     return mesh_fn
   return decorator


-def create_mesh(mesh_shape, axis_names, iota_order=False):
+def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None):
   size = math.prod(mesh_shape)
   if len(jax.devices()) < size:
     raise unittest.SkipTest(f"Test requires {size} global devices.")
   if iota_order:
     devices = sorted(jax.devices(), key=lambda d: d.id)
     mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
-    return jax.sharding.Mesh(mesh_devices, axis_names)
+    return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types)
   else:
-    return jax.make_mesh(mesh_shape, axis_names)
+    return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types)

 class _cached_property:
   null = object()
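A hedged sketch of how a test opts into non-default axis types via the updated helpers (the test name is assumed; the default remains all-User, matching the old behavior):

```python
@jtu.with_user_mesh((2, 2), ('x', 'y'),
                    axis_types={mesh_lib.AxisTypes.User: 'x',
                                mesh_lib.AxisTypes.Auto: 'y'})
def test_mixed_axis_types(self, mesh):
  ...  # `mesh` is already installed via mesh_lib.set_mesh
```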
@@ -46,7 +46,8 @@ from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
 from jax._src.core import Tracer
-from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
+from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh,
+                           get_abstract_mesh)
 from jax._src.api import _shared_code_pmap, _prepare_pmap
 from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           windowed_reductions, convolution, fft, linalg,
@@ -536,7 +537,7 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
                     for i, sz in enumerate(aval.shape))
   if config.sharding_in_types.value:
     new_mesh = AbstractMesh(
-        mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names})
+        mesh.shape_tuple, axis_types={AxisTypes.Collective: mesh.axis_names})
     new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
   else:
     new_sharding = None
@@ -548,11 +549,9 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
   assert isinstance(aval, core.ShapedArray)
   new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                     for i, sz in enumerate(aval.shape))
-  # TODO(yashkatariya): Reset the mesh properly based on the input avals if the
-  # mesh of shard_map specifies collective axes.
   if config.sharding_in_types.value:
     spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
-    new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
+    new_sharding = NamedSharding(get_abstract_mesh(), spec)
   else:
     new_sharding = None
   return aval.update(shape=new_shape, sharding=new_sharding)
@@ -4837,6 +4837,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

   @jtu.with_user_mesh((2, 2), ('model', 'data'))
   def test_aval_repr(self, mesh):
+    mesh = mesh.abstract_mesh
     aval = core.ShapedArray((128, 64), np.float32,
                             sharding=NamedSharding(mesh, P('model', 'data')))
     self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
@@ -4977,21 +4978,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)

-  def test_broadcasting_nary_error(self):
-    mesh1 = Mesh([jax.devices()[0]], 'x')
-    mesh2 = Mesh([jax.devices()[0]], 'y')
+  @jtu.with_user_mesh((1,), 'x')
+  def test_broadcasting_nary_error(self, mesh):
+    mesh2 = Mesh([jax.devices()[0]], 'y',
+                 axis_types={mesh_lib.AxisTypes.User: 'y'})

-    arr1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
+    arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P()))
     arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

     @jax.jit
     def f(x, y):
       return x + y

-    with config.sharding_in_types(True):
-      with self.assertRaisesRegex(
-          ValueError, "Mesh for all inputs should be equal"):
-        f(arr1, arr2)
+    with self.assertRaisesRegex(
+        ValueError, "Mesh for all inputs should be equal"):
+      f(arr1, arr2)

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_sin_unop(self, mesh):
@@ -5482,6 +5483,53 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     self.assertIn('@Sharding', f.lower(arr).as_text())

+  def test_auto_user(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'),
+                           axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x, x2):
+      y = x * 2
+      z = jnp.sin(y)
+      a = z @ x2
+      return a
+
+    with mesh_lib.set_mesh(mesh):
+      out = f(arr, arr.T)
+      self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
+      lowered_text = f.lower(arr, arr.T).as_text()
+      self.assertNotIn('unspecified_dims', lowered_text)
+
+    mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
+                            axis_types={mesh_lib.AxisTypes.User: 'x',
+                                        mesh_lib.AxisTypes.Auto: 'y'})
+    with mesh_lib.set_mesh(mesh2):
+      arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
+      arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
+      out = f(arr, arr2)
+      self.assertEqual(out.sharding, NamedSharding(mesh2, P('x', None)))
+      lowered_text = f.lower(arr, arr2).as_text()
+      self.assertTrue(lowered_text.count("unspecified_dims") == 3)
+
+    mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
+                            axis_types={mesh_lib.AxisTypes.User: 'y',
+                                        mesh_lib.AxisTypes.Auto: 'x'})
+    with mesh_lib.set_mesh(mesh3):
+      arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
+      arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P('y', 'x')))
+      out = f(arr, arr2)
+      self.assertEqual(out.sharding, NamedSharding(mesh3, P('x',)))
+      lowered_text = f.lower(arr, arr2).as_text()
+      self.assertTrue(lowered_text.count("unspecified_dims") == 4)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        "AxisTypes should be the same in a tuple subset of PartitionSpec"):
+      NamedSharding(mesh2, P(('x', 'y')))
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):