diff --git a/jax/BUILD b/jax/BUILD
index 285765549..aed97b2a3 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -866,6 +866,7 @@ pytype_strict_library(
         ":partition_spec",
         ":sharding",
         ":sharding_specs",
+        ":source_info_util",
         ":tree_util",
         ":util",
         ":xla_bridge",
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 5017f7f5f..c33f28824 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1696,6 +1696,9 @@ class ShapedArray(UnshapedArray):
     self.weak_type = weak_type
     if config.sharding_in_types.value:
       self.sharding = get_sharding(sharding, len(self.shape))
+      if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
+        raise ValueError(
+            f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")

   def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
     if shape is None:
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 2a59d5f5e..f4cfbb476 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -64,7 +64,7 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.sharding_impls import (PmapSharding, NamedSharding,
-                                     PartitionSpec as P)
+                                     PartitionSpec as P, canonicalize_sharding)
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
 from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
                            safe_map, safe_zip, split_list, weakref_lru_cache)
@@ -586,6 +586,8 @@ def _convert_element_type(
       isinstance(operand, Array)):
     sharding = operand.sharding

+  sharding = canonicalize_sharding(sharding, check_mesh_consistency=False)  # type: ignore
+
   if (warn_on_complex_to_real_cast and
       dtypes.issubdtype(old_dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
@@ -1431,6 +1433,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   if not config.sharding_in_types.value and sharding is not None:
     raise NotImplementedError("sharding argument to broadcast_in_dim is only "
                               "allowed when sharding_in_types config is on.")
+  sharding = canonicalize_sharding(sharding)
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
       isinstance(operand, Array) and sharding is None):
     return operand
@@ -1505,7 +1508,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
       return operand
     else:
       dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
-
+      sharding = canonicalize_sharding(sharding)
       return reshape_p.bind(
         operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
         dimensions=None if dims is None or same_dims else dims,
@@ -1947,7 +1950,7 @@ def iota(dtype: DTypeLike, size: int) -> Array:
   return broadcasted_iota(dtype, (size,), 0)

 def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
-                     _sharding=None) -> Array:
+                     sharding=None) -> Array:
   """Convenience wrapper around ``iota``."""
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = canonicalize_shape(shape)
@@ -1955,11 +1958,12 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
   static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
   dimension = core.concrete_or_error(
       int, dimension, "dimension argument of lax.broadcasted_iota")
-  if not config.sharding_in_types.value and _sharding is not None:
+  if not config.sharding_in_types.value and sharding is not None:
     raise NotImplementedError('sharding support for broadcasted_iota is not '
                               'implemented outside of sharding_in_types mode.')
+  sharding = canonicalize_sharding(sharding)
   return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
-                     dimension=dimension, sharding=_sharding)
+                     dimension=dimension, sharding=sharding)

 def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
   """Like numpy.eye, create a 2D array with ones on a diagonal."""
@@ -5560,7 +5564,7 @@ def _compute_argminmax(value_comparator, get_identity,
   axis, = axes
   indices = broadcasted_iota(
       index_dtype, np.shape(operand), axis,
-      _sharding=operand.sharding if config.sharding_in_types.value else None)
+      sharding=operand.sharding if config.sharding_in_types.value else None)
   res = reduce([operand, indices],
                [get_identity(operand.dtype), np.array(0, index_dtype)],
                _ArgMinMaxReducer(value_comparator),
diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index 7566e6bf3..dd08ca1e9 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -671,7 +671,7 @@ def _one_hot(x: Array, num_classes: int, *,
   else:
     rhs_sharding = None
   rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
-                             _sharding=rhs_sharding)
+                             sharding=rhs_sharding)
   return (lhs == rhs).astype(dtype)

 # TODO(slebedev): Change the type of `x` to `ArrayLike`.
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 2681cbc81..c4848503d 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -54,7 +54,7 @@ from jax._src.array import ArrayImpl
 from jax._src.core import ShapedArray
 from jax._src.custom_derivatives import custom_jvp
 from jax._src.lax import lax as lax_internal
-from jax._src.lax.lax import ( PrecisionLike,_array_copy,
+from jax._src.lax.lax import (PrecisionLike,_array_copy,
                               _sort_le_comparator, _sort_lt_comparator)
 from jax._src.lib import xla_client as xc
 from jax._src.numpy import reductions
@@ -69,8 +69,9 @@ from jax._src.util import (
     NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
     ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
     tuple_replace)
-from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
-                          PartitionSpec as P)
+from jax.sharding import Sharding
+from jax._src.sharding_impls import (SingleDeviceSharding, NamedSharding,
+                                     PartitionSpec as P, canonicalize_sharding)
 from jax.tree_util import tree_flatten, tree_leaves, tree_map
 import numpy as np
 import opt_einsum
@@ -9873,6 +9874,7 @@ def _einsum(
   if out_type is not None and not config.sharding_in_types.value:
     raise NotImplementedError("out_type only works when sharding_in_types "
                               "config is True.")
+  out_type = canonicalize_sharding(out_type)
   if out_type is not None and not isinstance(out_type, NamedSharding):
     raise NotImplementedError(
         "`out_type` argument of `einsum` only supports NamedSharding instances."
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 45c45ce11..80f900052 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -67,7 +67,8 @@ from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
-    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
+    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources,
+    parse_flatten_op_sharding, canonicalize_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
 from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
 from jax._src.traceback_util import api_boundary
@@ -2670,13 +2671,20 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: ()

 def sharding_cast(xs, shardings):
   if isinstance(shardings, NamedSharding):
-    return tree_map(lambda x: sharding_cast_p.bind(
-        x, src_sharding=x.sharding, dst_sharding=shardings), xs)
+    return tree_map(
+        lambda x: sharding_cast_p.bind(
+            x, src_sharding=x.sharding, dst_sharding=canonicalize_sharding(
+                shardings, check_mesh_consistency=False)),
+        xs)

   x_flat, treedef = tree_flatten(xs)
   shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings)
-  out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s)
-              for x, s in safe_zip(x_flat, shardings_flat)]
+  out_flat = [
+      sharding_cast_p.bind(
+          x, src_sharding=x.sharding,
+          dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
+      for x, s in safe_zip(x_flat, shardings_flat)
+  ]
   return tree_unflatten(treedef, out_flat)

 sharding_cast_p = core.Primitive('sharding_cast')
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 69e2adc4d..e98957af2 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -24,11 +24,13 @@ import math
 from typing import Any, NamedTuple, Union, cast

 from jax._src import core
+from jax._src import config
 from jax._src import mesh as mesh_lib
-from jax._src import sharding
+from jax._src import sharding as jsharding
 from jax._src import sharding_specs
 from jax._src import tree_util
 from jax._src import util
+from jax._src import source_info_util
 from jax._src import xla_bridge
 from jax._src import mesh_utils
 from jax._src.lib import xla_client as xc
@@ -45,7 +47,7 @@ Device = xc.Device
 Index = tuple[slice, ...]
 XLADeviceAssignment = tuple[Device, ...]
 # TODO(yashkatariya): Remove this after 3 months of deprecation.
-XLACompatibleSharding = sharding.Sharding
+XLACompatibleSharding = jsharding.Sharding

 @dataclasses.dataclass(frozen=True)
 class TransferToMemoryKind:
@@ -219,7 +221,7 @@ def named_sharding_to_xla_hlo_sharding(


 @use_cpp_class(xc.NamedSharding)
-class NamedSharding(sharding.Sharding):
+class NamedSharding(jsharding.Sharding):
   r"""A :class:`NamedSharding` expresses sharding using named axes.

   A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@@ -388,9 +390,6 @@ class NamedSharding(sharding.Sharding):
       spec = PartitionSpec(*spec)
     return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)

-  def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding:
-    return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind)
-
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

@@ -415,7 +414,7 @@ def get_replicated_hlo_sharding():


 @use_cpp_class(xc.SingleDeviceSharding)
-class SingleDeviceSharding(sharding.Sharding):
+class SingleDeviceSharding(jsharding.Sharding):
   """A :class:`Sharding` that places its data on a single device.

   Args:
@@ -503,7 +502,7 @@ def pmap_sharding_devices_indices_map(


 @use_cpp_class(xc.PmapSharding)
-class PmapSharding(sharding.Sharding):
+class PmapSharding(jsharding.Sharding):
   """Describes a sharding used by :func:`jax.pmap`."""
   devices: np.ndarray
   sharding_spec: sharding_specs.ShardingSpec
@@ -713,7 +712,7 @@ def _positional_sharding_to_xla_hlo_sharding(
   return xc.HloSharding.from_proto(pbuf)


-class PositionalSharding(sharding.Sharding):
+class PositionalSharding(jsharding.Sharding):
   _devices: tuple[xc.Device, ...]
   _memory_kind: str | None
   _ids: np.ndarray  # dtype DeviceIdSet
@@ -820,7 +819,7 @@ class PositionalSharding(sharding.Sharding):
   def is_fully_replicated(self) -> bool:
     return self.shape == (1,) * self.ndim

-  # sharding.Sharding interface
+  # jsharding.Sharding interface

   @property
   def _device_assignment(self) -> XLADeviceAssignment:
@@ -868,7 +867,7 @@ class DeviceIdSet:


 @use_cpp_class(xc.GSPMDSharding)
-class GSPMDSharding(sharding.Sharding):
+class GSPMDSharding(jsharding.Sharding):
   _devices: tuple[Device, ...]
   _hlo_sharding: xc.HloSharding
   _memory_kind: str | None
@@ -1122,7 +1121,7 @@ def prepare_axis_resources(axis_resources, arg_name,
   for entry in entries:
     if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
       new_entries.append(entry)
-    elif isinstance(entry, sharding.Sharding):
+    elif isinstance(entry, jsharding.Sharding):
       if isinstance(entry, PmapSharding):
         raise ValueError(f'One of {what} got sharding {entry} which is not '
                          'allowed.')
@@ -1138,7 +1137,7 @@ def prepare_axis_resources(axis_resources, arg_name,
 def _check_unique_resources(axis_resources, arg_name):
   for arg_axis_resources in axis_resources:
     if not arg_axis_resources: continue
-    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)):
+    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
       continue
     constrained_dims = [d for d in arg_axis_resources if d is not None]
     resource_counts = collections.Counter(
@@ -1371,7 +1370,7 @@ class NonUniformShardingError(ValueError):


 def get_process_index_and_count(
-    tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
+    tensor_sharding: jsharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
   """Get current process index and number of unique processes for given dimension.

   This function facilitates mapping of process-level data to individual
@@ -1486,7 +1485,7 @@ def get_process_index_and_count(


 def local_to_global_shape(
-    sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
+    sharding: jsharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
   """Computes the global shape given the per process if possible.

   The returned shape will have the size of the global tensor in that dimension
@@ -1545,7 +1544,7 @@ def local_to_global_shape(


 def num_addressable_indices(
-    tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
+    tensor_sharding: jsharding.Sharding, dim: int, global_shape: Shape) -> int:
   """Returns the number of indices for given dimension this host has access to.

   Each host can have multiple number of devices that are spanning
@@ -1579,7 +1578,7 @@
   """
   # TODO(sandler, yashkatariya): Consider making this function public.
   addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
-  addressables = cast(Mapping[sharding.Device, Index], addressables)
+  addressables = cast(Mapping[jsharding.Device, Index], addressables)
   num_unique_slices = len({
       _slice_as_tuple(addressable[dim]) for addressable in addressables.values()
   })
@@ -1596,7 +1595,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
     new_op_sharding.tile_assignment_dimensions = tad
   return xc.HloSharding.from_proto(new_op_sharding)

-def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
+def is_single_device_sharding(sharding: jsharding.Sharding) -> bool:
   # Special case PmapSharding here because PmapSharding maps away an axis
   # and needs to be handled separately.test_pjit_single_device_sharding_add
   return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)
@@ -1625,7 +1624,7 @@ def make_key_array_phys_sharding(aval, sharding):


 def physical_sharding(
-    aval, sharding: sharding.Sharding) -> sharding.Sharding:
+    aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
   return make_key_array_phys_sharding(aval, sharding)


@@ -1642,7 +1641,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
   return GSPMDSharding(phys_sharding._device_assignment,
                        xc.HloSharding.from_proto(logical_op_sharding))

-def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
+def check_replicated_trailing_dims(sharding: jsharding.Sharding, aval):
   if isinstance(sharding, PmapSharding):
     return
   phys_aval = core.physical_aval(aval)
@@ -1655,7 +1654,7 @@ def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
         f" sharding: {sharding}, partitions: {partitions}, "
         f"num_trailing_dims: {num_trailing_dims}")

-def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
+def logical_sharding(aval, phys_sharding) -> jsharding.Sharding:
   # The trailing dims should always be replicated.
   check_replicated_trailing_dims(phys_sharding, aval)

@@ -1695,6 +1694,44 @@ def _gspmd_to_named_sharding_via_mesh(
       mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
       out_s.memory_kind)


+def flatten_spec(spec):
+  out = []
+  for s in spec:
+    if s is None:
+      continue
+    if isinstance(s, tuple):
+      out.extend(s)
+    else:
+      out.append(s)
+  return out
+
+def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
+                          check_mesh_consistency: bool = True
+                          ) -> NamedSharding | None:
+  if not config.sharding_in_types.value:
+    return sharding  # type: ignore
+  if sharding is None:
+    return sharding
+
+  if isinstance(sharding, PartitionSpec):
+    sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding)  # type: ignore
+  else:
+    if (check_mesh_consistency and
+        sharding.mesh != mesh_lib.get_abstract_mesh()):
+      raise ValueError(
+          f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
+          f' of sharding {sharding.mesh}. This error occurs at source: '
+          f' {source_info_util.summarize(source_info_util.current())}')
+
+  for s in flatten_spec(sharding.spec):
+    if sharding.mesh._name_to_type[s] in {
+        mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
+      raise ValueError(
+          'PartitionSpec cannot contain axis names that are of type Auto or'
+          f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
+          f' {s} or type: {sharding.mesh._name_to_type[s]}')
+  return sharding
+
 def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
               *, devices: Sequence[xc.Device] | None = None,
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 3fcc5c81a..1d9dee768 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4858,8 +4858,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @jax.jit
     def f(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return jnp.sum(out)

@@ -5155,9 +5154,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x):
       x = x * 2
-      y = jax.lax.broadcasted_iota(
-          x.dtype, (8, 2), 0,
-          _sharding=NamedSharding(mesh.abstract_mesh, P('x', 'y')))
+      y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       return x, y

@@ -5186,8 +5183,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @jax.jit
     def g(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out

@@ -5216,9 +5212,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @jax.jit
     def h(x, y):
-      s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None))
-      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s)
-      self.assertEqual(out.sharding.spec, s.spec)
+      spec = P('x', None, 'y', None)
+      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=spec)
+      self.assertEqual(out.sharding.spec, spec)
       return out

     arr1 = jax.device_put(np_inp.reshape(8, 4, 2),
@@ -5263,8 +5259,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       self.assertEqual(y.sharding.spec, dst_spec)
       return y

-    new_s = (NamedSharding(mesh.abstract_mesh, dst_spec)
-             if use_sharding_arg else None)
+    new_s = dst_spec if use_sharding_arg else None
     out = f(arr, new_s)
     self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec))
     self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)
@@ -5659,15 +5654,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       y = x * 2
       auto_mesh = mesh_lib.get_abstract_mesh().with_axis_types(
           {mesh_lib.AxisTypes.Auto: ('x', 'y')})
-      y = sharding_cast(y, y.sharding.with_mesh(auto_mesh))
       with mesh_lib.set_abstract_mesh(auto_mesh):
+        y = sharding_cast(y, P(None, None))
         self.assertEqual(y.sharding.spec, P(None, None))
         z = jnp.sin(y)
         self.assertEqual(z.sharding.spec, P(None, None))
         a = z @ z.T
         self.assertEqual(a.sharding.spec, P(None, None))
-        a = sharding_cast(
-            a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
+        a = sharding_cast(a, P('x', None))
         self.assertEqual(a.sharding.spec, P('x', None))
         return a

@@ -5697,8 +5691,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         self.assertEqual(z.sharding.spec, P(None, 'y'))
         a = z @ z.T
         self.assertEqual(a.sharding.spec, P(None, None))
-        a = sharding_cast(
-            a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
+        a = sharding_cast(a, P(None, None))
         self.assertEqual(a.sharding.spec, P(None, None))
         return a

@@ -5719,15 +5712,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       y = x * 2
       mix_mesh = mesh_lib.get_abstract_mesh().with_axis_types(
          {mesh_lib.AxisTypes.Auto: 'x', mesh_lib.AxisTypes.User: 'y'})
-      y = sharding_cast(y, y.sharding.with_mesh(mix_mesh))
       with mesh_lib.set_abstract_mesh(mix_mesh):
+        y = sharding_cast(y, P(None, 'y'))
         self.assertEqual(y.sharding.spec, P(None, 'y'))
         z = jnp.sin(y)
         self.assertEqual(z.sharding.spec, P(None, 'y'))
         a = z @ z.T
         self.assertEqual(a.sharding.spec, P(None, None))
-        a = sharding_cast(
-            a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
+        a = sharding_cast(a, P('x', None))
         self.assertEqual(a.sharding.spec, P('x', None))
         return a

@@ -5759,8 +5751,10 @@ class ShardingInTypesTest(jtu.JaxTestCase):

   def test_sharding_cast_src_dst_mesh_mismatch(self):
     np_inp = np.arange(16.).reshape(8, 2)
-    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
-    mesh2 = jtu.create_mesh((2, 1), ('a', 'b'))
+    mesh = jtu.create_mesh((2, 1), ('x', 'y'),
+                           axis_types={mesh_lib.AxisTypes.User: ('x', 'y')})
+    mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
+                            axis_types={mesh_lib.AxisTypes.User: ('a', 'b')})
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
     f = lambda x: sharding_cast(x, NamedSharding(mesh2, P('a', 'b')))
@@ -5812,7 +5806,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     def f(x):
       auto_mesh = get_abstract_mesh().with_axis_types({AxisTypes.Auto: 'x'})
       with set_abstract_mesh(auto_mesh):
-        x = sharding_cast(x, x.sharding.with_mesh(auto_mesh))
+        x = sharding_cast(x, P(None, None))
       return x

     self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
@@ -5850,13 +5844,23 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(auto_mesh, P('x', None)))
+                       out_type=NamedSharding(auto_mesh, P(None, None)))
       return out

     with self.assertRaisesRegex(
-        ValueError, "context mesh.* should match the aval mesh"):
+        ValueError, "Context mesh.*should match the mesh of sharding"):
       f(arr1, arr2)

+    @jax.jit
+    def g(x, y):
+      with mesh_lib.set_abstract_mesh(auto_mesh):
+        out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      return out
+
+    with self.assertRaisesRegex(
+        ValueError, "PartitionSpec cannot contain axis names.*Auto"):
+      g(arr1, arr2)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
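
Reviewer note, not part of the patch: a minimal sketch of the call-site convention this change enables. It assumes a JAX build at this commit with the sharding_in_types config enabled, at least two devices, and the same test-utility mesh helper used in pjit_test.py; the names mesh, arr, and f below are illustrative, not additions to the diff.

# Sketch only: with canonicalize_sharding in place, a bare PartitionSpec is
# resolved against the abstract mesh of the current trace, so callers no
# longer build NamedSharding(x.sharding.mesh, ...) by hand.
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu          # test-only mesh helper, as in pjit_test.py
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_mesh((2, 1), ('x', 'y'))     # hypothetical 2x1 device mesh
arr = jax.device_put(np.arange(16.).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(x):
  # Both calls previously required a NamedSharding constructed from the
  # operand's mesh; a bare spec is now accepted (sharding_in_types only).
  iota = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
  out = jnp.einsum('xy,yz->xz', x, x.T, out_type=P('x', None))
  return out, iota

out, iota = f(arr)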
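Reviewer note, not part of the patch: the contract of the new canonicalize_sharding helper, exercised directly. Same assumptions as above; whether this runs outside a traced function, and the default axis types produced by jtu.create_mesh, are assumptions, which is why the mesh is created with explicit User axis types as in test_sharding_cast_src_dst_mesh_mismatch.

# Sketch only: what canonicalize_sharding accepts and rejects.
from jax._src import mesh as mesh_lib
from jax._src import test_util as jtu
from jax._src.sharding_impls import canonicalize_sharding
from jax.sharding import PartitionSpec as P

abstract_mesh = jtu.create_mesh(
    (2, 1), ('x', 'y'),
    axis_types={mesh_lib.AxisTypes.User: ('x', 'y')}).abstract_mesh

with mesh_lib.set_abstract_mesh(abstract_mesh):
  # A bare spec is promoted to a NamedSharding on the current abstract mesh.
  s = canonicalize_sharding(P('x', None))
  assert s.mesh == mesh_lib.get_abstract_mesh()
  # Error path 1: a NamedSharding whose mesh differs from the context mesh
  # fails the consistency check ("Context mesh ... should match the mesh of
  # sharding ..."), unless the caller passes check_mesh_consistency=False as
  # sharding_cast and _convert_element_type do.
  # Error path 2: a spec naming an axis whose type is Auto or Collective
  # raises "PartitionSpec cannot contain axis names that are of type Auto or
  # Collective", as the new test g() at the end of this patch checks.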