[sharding_in_types] Functions like `einsum`, `reshape`, `broadcast_in_dim`, `broadcasted_iota`, `convert_element_type` and `sharding_cast` that take `out_sharding` as an argument in their signature should also accept a `PartitionSpec`, not just a `NamedSharding`.

If a `PartitionSpec` is passed, the mesh is read from the context. The primitives themselves still take only `NamedSharding`; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
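
A minimal sketch of the new calling convention, assuming `sharding_in_types` is enabled and an abstract mesh with `'x'` and `'y'` axes is the current context mesh (`x` and `y` below stand for arrays already placed on that mesh inside a jitted function, as in the updated tests at the end of this change):

    from jax.sharding import PartitionSpec as P

    # Before: an explicit NamedSharding had to be constructed from the mesh.
    out = jnp.einsum('xy,yz->xz', x, y,
                     out_type=NamedSharding(x.sharding.mesh, P('x', None)))

    # After: a bare PartitionSpec is enough; the mesh comes from the context
    # and the conversion to NamedSharding happens above `.bind`.
    out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
    iota = jax.lax.broadcasted_iota(jnp.int32, (8, 2), 0, sharding=P('x', 'y'))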

For the functions above, we also raise an error if the `PartitionSpec` contains mesh axis names whose type is `Auto` or `Collective`.
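
For example, mirroring the new test added below (`auto_mesh`, `x` and `y` are assumed names: a context mesh whose axes are of type `Auto`, and two arrays placed on it), the following now raises:

    with mesh_lib.set_abstract_mesh(auto_mesh):
      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
    # ValueError: PartitionSpec cannot contain axis names that are of type Auto
    # or Collective. ...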

PiperOrigin-RevId: 713352542
Author: Yash Katariya, 2025-01-08 11:10:37 -08:00 (committed by jax authors)
Parent: c1a60c676a
Commit: 3848f0d2ac
8 changed files with 120 additions and 61 deletions


@ -866,6 +866,7 @@ pytype_strict_library(
":partition_spec",
":sharding",
":sharding_specs",
":source_info_util",
":tree_util",
":util",
":xla_bridge",


@ -1696,6 +1696,9 @@ class ShapedArray(UnshapedArray):
self.weak_type = weak_type
if config.sharding_in_types.value:
self.sharding = get_sharding(sharding, len(self.shape))
if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
raise ValueError(
f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:


@ -64,7 +64,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding_impls import (PmapSharding, NamedSharding,
PartitionSpec as P)
PartitionSpec as P, canonicalize_sharding)
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
safe_map, safe_zip, split_list, weakref_lru_cache)
@ -586,6 +586,8 @@ def _convert_element_type(
isinstance(operand, Array)):
sharding = operand.sharding
sharding = canonicalize_sharding(sharding, check_mesh_consistency=False) # type: ignore
if (warn_on_complex_to_real_cast and
dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
@ -1431,6 +1433,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
if not config.sharding_in_types.value and sharding is not None:
raise NotImplementedError("sharding argument to broadcast_in_dim is only "
"allowed when sharding_in_types config is on.")
sharding = canonicalize_sharding(sharding)
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
isinstance(operand, Array) and sharding is None):
return operand
@ -1505,7 +1508,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
return operand
else:
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
sharding = canonicalize_sharding(sharding)
return reshape_p.bind(
operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
dimensions=None if dims is None or same_dims else dims,
@ -1947,7 +1950,7 @@ def iota(dtype: DTypeLike, size: int) -> Array:
return broadcasted_iota(dtype, (size,), 0)
def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
_sharding=None) -> Array:
sharding=None) -> Array:
"""Convenience wrapper around ``iota``."""
dtype = dtypes.canonicalize_dtype(dtype)
shape = canonicalize_shape(shape)
@ -1955,11 +1958,12 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
dimension = core.concrete_or_error(
int, dimension, "dimension argument of lax.broadcasted_iota")
if not config.sharding_in_types.value and _sharding is not None:
if not config.sharding_in_types.value and sharding is not None:
raise NotImplementedError('sharding support for broadcasted_iota is not '
'implemented outside of sharding_in_types mode.')
sharding = canonicalize_sharding(sharding)
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
dimension=dimension, sharding=_sharding)
dimension=dimension, sharding=sharding)
def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
@ -5560,7 +5564,7 @@ def _compute_argminmax(value_comparator, get_identity,
axis, = axes
indices = broadcasted_iota(
index_dtype, np.shape(operand), axis,
_sharding=operand.sharding if config.sharding_in_types.value else None)
sharding=operand.sharding if config.sharding_in_types.value else None)
res = reduce([operand, indices],
[get_identity(operand.dtype), np.array(0, index_dtype)],
_ArgMinMaxReducer(value_comparator),


@ -671,7 +671,7 @@ def _one_hot(x: Array, num_classes: int, *,
else:
rhs_sharding = None
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
_sharding=rhs_sharding)
sharding=rhs_sharding)
return (lhs == rhs).astype(dtype)
# TODO(slebedev): Change the type of `x` to `ArrayLike`.


@ -54,7 +54,7 @@ from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import ( PrecisionLike,_array_copy,
from jax._src.lax.lax import (PrecisionLike,_array_copy,
_sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc
from jax._src.numpy import reductions
@ -69,8 +69,9 @@ from jax._src.util import (
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
tuple_replace)
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
PartitionSpec as P)
from jax.sharding import Sharding
from jax._src.sharding_impls import (SingleDeviceSharding, NamedSharding,
PartitionSpec as P, canonicalize_sharding)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np
import opt_einsum
@ -9873,6 +9874,7 @@ def _einsum(
if out_type is not None and not config.sharding_in_types.value:
raise NotImplementedError("out_type only works when sharding_in_types "
"config is True.")
out_type = canonicalize_sharding(out_type)
if out_type is not None and not isinstance(out_type, NamedSharding):
raise NotImplementedError(
"`out_type` argument of `einsum` only supports NamedSharding instances."


@ -67,7 +67,8 @@ from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding,
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
ParsedPartitionSpec, get_single_pspec, prepare_axis_resources,
parse_flatten_op_sharding, canonicalize_sharding)
from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
from jax._src.traceback_util import api_boundary
@ -2670,13 +2671,20 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: ()
def sharding_cast(xs, shardings):
if isinstance(shardings, NamedSharding):
return tree_map(lambda x: sharding_cast_p.bind(
x, src_sharding=x.sharding, dst_sharding=shardings), xs)
return tree_map(
lambda x: sharding_cast_p.bind(
x, src_sharding=x.sharding, dst_sharding=canonicalize_sharding(
shardings, check_mesh_consistency=False)),
xs)
x_flat, treedef = tree_flatten(xs)
shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings)
out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s)
for x, s in safe_zip(x_flat, shardings_flat)]
out_flat = [
sharding_cast_p.bind(
x, src_sharding=x.sharding,
dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
for x, s in safe_zip(x_flat, shardings_flat)
]
return tree_unflatten(treedef, out_flat)
sharding_cast_p = core.Primitive('sharding_cast')


@ -24,11 +24,13 @@ import math
from typing import Any, NamedTuple, Union, cast
from jax._src import core
from jax._src import config
from jax._src import mesh as mesh_lib
from jax._src import sharding
from jax._src import sharding as jsharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import source_info_util
from jax._src import xla_bridge
from jax._src import mesh_utils
from jax._src.lib import xla_client as xc
@ -45,7 +47,7 @@ Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = tuple[Device, ...]
# TODO(yashkatariya): Remove this after 3 months of deprecation.
XLACompatibleSharding = sharding.Sharding
XLACompatibleSharding = jsharding.Sharding
@dataclasses.dataclass(frozen=True)
class TransferToMemoryKind:
@ -219,7 +221,7 @@ def named_sharding_to_xla_hlo_sharding(
@use_cpp_class(xc.NamedSharding)
class NamedSharding(sharding.Sharding):
class NamedSharding(jsharding.Sharding):
r"""A :class:`NamedSharding` expresses sharding using named axes.
A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@ -388,9 +390,6 @@ class NamedSharding(sharding.Sharding):
spec = PartitionSpec(*spec)
return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)
def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding:
return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind)
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
@ -415,7 +414,7 @@ def get_replicated_hlo_sharding():
@use_cpp_class(xc.SingleDeviceSharding)
class SingleDeviceSharding(sharding.Sharding):
class SingleDeviceSharding(jsharding.Sharding):
"""A :class:`Sharding` that places its data on a single device.
Args:
@ -503,7 +502,7 @@ def pmap_sharding_devices_indices_map(
@use_cpp_class(xc.PmapSharding)
class PmapSharding(sharding.Sharding):
class PmapSharding(jsharding.Sharding):
"""Describes a sharding used by :func:`jax.pmap`."""
devices: np.ndarray
sharding_spec: sharding_specs.ShardingSpec
@ -713,7 +712,7 @@ def _positional_sharding_to_xla_hlo_sharding(
return xc.HloSharding.from_proto(pbuf)
class PositionalSharding(sharding.Sharding):
class PositionalSharding(jsharding.Sharding):
_devices: tuple[xc.Device, ...]
_memory_kind: str | None
_ids: np.ndarray # dtype DeviceIdSet
@ -820,7 +819,7 @@ class PositionalSharding(sharding.Sharding):
def is_fully_replicated(self) -> bool:
return self.shape == (1,) * self.ndim
# sharding.Sharding interface
# jsharding.Sharding interface
@property
def _device_assignment(self) -> XLADeviceAssignment:
@ -868,7 +867,7 @@ class DeviceIdSet:
@use_cpp_class(xc.GSPMDSharding)
class GSPMDSharding(sharding.Sharding):
class GSPMDSharding(jsharding.Sharding):
_devices: tuple[Device, ...]
_hlo_sharding: xc.HloSharding
_memory_kind: str | None
@ -1122,7 +1121,7 @@ def prepare_axis_resources(axis_resources, arg_name,
for entry in entries:
if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
new_entries.append(entry)
elif isinstance(entry, sharding.Sharding):
elif isinstance(entry, jsharding.Sharding):
if isinstance(entry, PmapSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not '
'allowed.')
@ -1138,7 +1137,7 @@ def prepare_axis_resources(axis_resources, arg_name,
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)):
if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
continue
constrained_dims = [d for d in arg_axis_resources if d is not None]
resource_counts = collections.Counter(
@ -1371,7 +1370,7 @@ class NonUniformShardingError(ValueError):
def get_process_index_and_count(
tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
tensor_sharding: jsharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
"""Get current process index and number of unique processes for given dimension.
This function facilitates mapping of process-level data to individual
@ -1486,7 +1485,7 @@ def get_process_index_and_count(
def local_to_global_shape(
sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
sharding: jsharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
"""Computes the global shape given the per process if possible.
The returned shape will have the size of the global tensor in that dimension
@ -1545,7 +1544,7 @@ def local_to_global_shape(
def num_addressable_indices(
tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
tensor_sharding: jsharding.Sharding, dim: int, global_shape: Shape) -> int:
"""Returns the number of indices for given dimension this host has access to.
Each host can have multiple number of devices that are spanning
@ -1579,7 +1578,7 @@ def num_addressable_indices(
"""
# TODO(sandler, yashkatariya): Consider making this function public.
addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
addressables = cast(Mapping[sharding.Device, Index], addressables)
addressables = cast(Mapping[jsharding.Device, Index], addressables)
num_unique_slices = len({
_slice_as_tuple(addressable[dim]) for addressable in addressables.values()
})
@ -1596,7 +1595,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)
def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
def is_single_device_sharding(sharding: jsharding.Sharding) -> bool:
# Special case PmapSharding here because PmapSharding maps away an axis
# and needs to be handled separately.test_pjit_single_device_sharding_add
return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)
@ -1625,7 +1624,7 @@ def make_key_array_phys_sharding(aval, sharding):
def physical_sharding(
aval, sharding: sharding.Sharding) -> sharding.Sharding:
aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
return make_key_array_phys_sharding(aval, sharding)
@ -1642,7 +1641,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
return GSPMDSharding(phys_sharding._device_assignment,
xc.HloSharding.from_proto(logical_op_sharding))
def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
def check_replicated_trailing_dims(sharding: jsharding.Sharding, aval):
if isinstance(sharding, PmapSharding):
return
phys_aval = core.physical_aval(aval)
@ -1655,7 +1654,7 @@ def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
f" sharding: {sharding}, partitions: {partitions}, "
f"num_trailing_dims: {num_trailing_dims}")
def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
def logical_sharding(aval, phys_sharding) -> jsharding.Sharding:
# The trailing dims should always be replicated.
check_replicated_trailing_dims(phys_sharding, aval)
@ -1695,6 +1694,44 @@ def _gspmd_to_named_sharding_via_mesh(
mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
out_s.memory_kind)
def flatten_spec(spec):
out = []
for s in spec:
if s is None:
continue
if isinstance(s, tuple):
out.extend(s)
else:
out.append(s)
return out
def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
check_mesh_consistency: bool = True
) -> NamedSharding | None:
if not config.sharding_in_types.value:
return sharding # type: ignore
if sharding is None:
return sharding
if isinstance(sharding, PartitionSpec):
sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding) # type: ignore
else:
if (check_mesh_consistency and
sharding.mesh != mesh_lib.get_abstract_mesh()):
raise ValueError(
f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
f' of sharding {sharding.mesh}. This error occurs at source: '
f' {source_info_util.summarize(source_info_util.current())}')
for s in flatten_spec(sharding.spec):
if sharding.mesh._name_to_type[s] in {
mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
raise ValueError(
'PartitionSpec cannot contain axis names that are of type Auto or'
f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
f' {s} or type: {sharding.mesh._name_to_type[s]}')
return sharding
def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
*, devices: Sequence[xc.Device] | None = None,


@ -4858,8 +4858,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, y):
out = jnp.einsum('xy,yz->xz', x, y,
out_type=NamedSharding(x.sharding.mesh, P('x', None)))
out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
self.assertEqual(out.sharding.spec, P('x', None))
return jnp.sum(out)
@ -5155,9 +5154,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def g(x):
x = x * 2
y = jax.lax.broadcasted_iota(
x.dtype, (8, 2), 0,
_sharding=NamedSharding(mesh.abstract_mesh, P('x', 'y')))
y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
self.assertEqual(y.sharding.spec, P('x', 'y'))
return x, y
@ -5186,8 +5183,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def g(x, y):
out = jnp.einsum('xy,yz->xz', x, y,
out_type=NamedSharding(x.sharding.mesh, P('x', None)))
out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
self.assertEqual(out.sharding.spec, P('x', None))
return out
@ -5216,9 +5212,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def h(x, y):
s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None))
out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s)
self.assertEqual(out.sharding.spec, s.spec)
spec = P('x', None, 'y', None)
out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=spec)
self.assertEqual(out.sharding.spec, spec)
return out
arr1 = jax.device_put(np_inp.reshape(8, 4, 2),
@ -5263,8 +5259,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(y.sharding.spec, dst_spec)
return y
new_s = (NamedSharding(mesh.abstract_mesh, dst_spec)
if use_sharding_arg else None)
new_s = dst_spec if use_sharding_arg else None
out = f(arr, new_s)
self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec))
self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)
@ -5659,15 +5654,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
y = x * 2
auto_mesh = mesh_lib.get_abstract_mesh().with_axis_types(
{mesh_lib.AxisTypes.Auto: ('x', 'y')})
y = sharding_cast(y, y.sharding.with_mesh(auto_mesh))
with mesh_lib.set_abstract_mesh(auto_mesh):
y = sharding_cast(y, P(None, None))
self.assertEqual(y.sharding.spec, P(None, None))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, None))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
a = sharding_cast(
a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
a = sharding_cast(a, P('x', None))
self.assertEqual(a.sharding.spec, P('x', None))
return a
@ -5697,8 +5691,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(z.sharding.spec, P(None, 'y'))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
a = sharding_cast(
a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
a = sharding_cast(a, P(None, None))
self.assertEqual(a.sharding.spec, P(None, None))
return a
@ -5719,15 +5712,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
y = x * 2
mix_mesh = mesh_lib.get_abstract_mesh().with_axis_types(
{mesh_lib.AxisTypes.Auto: 'x', mesh_lib.AxisTypes.User: 'y'})
y = sharding_cast(y, y.sharding.with_mesh(mix_mesh))
with mesh_lib.set_abstract_mesh(mix_mesh):
y = sharding_cast(y, P(None, 'y'))
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, 'y'))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
a = sharding_cast(
a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
a = sharding_cast(a, P('x', None))
self.assertEqual(a.sharding.spec, P('x', None))
return a
@ -5759,8 +5751,10 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def test_sharding_cast_src_dst_mesh_mismatch(self):
np_inp = np.arange(16.).reshape(8, 2)
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
mesh2 = jtu.create_mesh((2, 1), ('a', 'b'))
mesh = jtu.create_mesh((2, 1), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.User: ('x', 'y')})
mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
axis_types={mesh_lib.AxisTypes.User: ('a', 'b')})
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
f = lambda x: sharding_cast(x, NamedSharding(mesh2, P('a', 'b')))
@ -5812,7 +5806,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def f(x):
auto_mesh = get_abstract_mesh().with_axis_types({AxisTypes.Auto: 'x'})
with set_abstract_mesh(auto_mesh):
x = sharding_cast(x, x.sharding.with_mesh(auto_mesh))
x = sharding_cast(x, P(None, None))
return x
self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
@ -5850,13 +5844,23 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, y):
out = jnp.einsum('xy,yz->xz', x, y,
out_type=NamedSharding(auto_mesh, P('x', None)))
out_type=NamedSharding(auto_mesh, P(None, None)))
return out
with self.assertRaisesRegex(
ValueError, "context mesh.* should match the aval mesh"):
ValueError, "Context mesh.*should match the mesh of sharding"):
f(arr1, arr2)
@jax.jit
def g(x, y):
with mesh_lib.set_abstract_mesh(auto_mesh):
out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
return out
with self.assertRaisesRegex(
ValueError, "PartitionSpec cannot contain axis names.*Auto"):
g(arr1, arr2)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):