Don't allow users to query `tracer.sharding`, even in sharding-in-types mode.

Instead, users should query `tracer.aval.sharding` so that code behaves the same under `jit` and in eager mode.
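A minimal sketch of the intended pattern, assuming a JAX version where avals carry shardings by default (at the time of this change the behavior was gated behind the experimental `config.sharding_in_types` flag); the shapes here are illustrative:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      # `x` is a Tracer here. After this change, `x.sharding` raises
      # AttributeError ("The 'sharding' attribute is not available on ...").
      # The type-level sharding lives on the aval and is trace-safe:
      assert x.aval.sharding.spec is not None
      return x * 2

    out = f(jnp.arange(4.0))
    _ = out.aval.sharding.spec  # eager mode: the identical attribute path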

PiperOrigin-RevId: 717638986
Yash Katariya 2025-01-20 15:12:12 -08:00 committed by jax authors
parent 7f19b345fb
commit d50d1e2c40
9 changed files with 167 additions and 147 deletions

View File

@@ -67,8 +67,7 @@ from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
NamedSharding)
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
@@ -2564,11 +2563,7 @@ def _sds_aval_mapping(x):
aval = ShapedArray(
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
if config.sharding_in_types.value and isinstance(x.sharding, NamedSharding):
return aval.update(sharding=NamedSharding(
x.sharding.mesh.abstract_mesh,
x.sharding.spec._normalized_spec(x.ndim)))
return aval
return core.update_aval_with_sharding(aval, x.sharding)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping

View File

@@ -41,7 +41,7 @@ from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding,
PmapSharding, SingleDeviceSharding,
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
@@ -753,7 +753,8 @@ def make_array_from_callback(
first_value = per_device_values[0]
expected_dtype = first_value.dtype
expected_shape = sharding.shard_shape(shape)
aval = core.ShapedArray(shape, expected_dtype)
aval = core.update_aval_with_sharding(
core.ShapedArray(shape, expected_dtype), sharding)
_validate_shape_and_dtype_for_per_device_arrays(
per_device_values,
expected_shape=expected_shape,
@@ -1017,7 +1018,8 @@ def make_array_from_single_device_arrays(
raise ValueError(
"jax.make_array_from_single_device_arrays requires a list of concrete"
f" arrays as input. got types {set(map(type, arrays))}")
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
aval = core.update_aval_with_sharding(
core.ShapedArray(shape, arrays[0].dtype, weak_type=False), sharding)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
committed=True)
@@ -1028,13 +1030,7 @@
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
return self.aval.update(sharding=NamedSharding(
self.sharding.mesh.abstract_mesh,
self.sharding.spec._normalized_spec(self.ndim)))
else:
return self.aval
return core.update_aval_with_sharding(self.aval, self.sharding)
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
# TODO(jakevdp) replace this with true inheritance at the C++ level.
@@ -1179,6 +1175,7 @@ pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
def _array_global_result_handler(global_aval, out_sharding, committed):
global_aval = core.update_aval_with_sharding(global_aval, out_sharding)
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0)
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):

View File

@@ -816,6 +816,12 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
# if the aval property raises an AttributeError, gets caught here
assert not config.enable_checks.value or name != "aval"
if name == 'sharding':
raise AttributeError(
self,
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")
try:
attr = getattr(self.aval, name)
except AttributeError as err:
@@ -1421,6 +1427,13 @@ def check_valid_jaxtype(x):
raise TypeError(
f"Value {x!r} of type {type(x)} is not a valid JAX type")
def update_aval_with_sharding(aval, sharding):
from jax._src.sharding_impls import NamedSharding # type: ignore
if config.sharding_in_types.value and isinstance(sharding, NamedSharding):
aval = aval.update(sharding=NamedSharding(
sharding.mesh.abstract_mesh, sharding.spec._normalized_spec(aval.ndim)))
return aval
# We have three flavors of abstractification APIs here which each used to have
# their own separate implementation. Now they're effectively the same, with the
@@ -1433,8 +1446,6 @@ def check_valid_jaxtype(x):
# TODO(jakevdp): can these be unified further?
def shaped_abstractify(x):
from jax._src.sharding_impls import NamedSharding # type: ignore
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
@@ -1448,12 +1459,7 @@ def shaped_abstractify(x):
if hasattr(x, 'dtype'):
aval = ShapedArray(np.shape(x), x.dtype,
weak_type=getattr(x, 'weak_type', False))
if (config.sharding_in_types.value and hasattr(x, 'sharding') and
isinstance(x.sharding, NamedSharding)):
return aval.update(sharding=NamedSharding(
x.sharding.mesh.abstract_mesh,
x.sharding.spec._normalized_spec(aval.ndim)))
return aval
return update_aval_with_sharding(aval, getattr(x, 'sharding', None))
raise TypeError(
f"Cannot interpret value of type {typ} as an abstract array; it "
"does not have a dtype attribute")
@@ -1701,13 +1707,17 @@ def get_sharding(sharding, ndim):
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
return _maybe_modify_sharding(sharding)
context_mesh = mesh_lib.get_abstract_mesh()
if not context_mesh:
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
return NamedSharding(context_mesh, P(*[None] * ndim))
out_s = _maybe_modify_sharding(sharding)
else:
context_mesh = mesh_lib.get_abstract_mesh()
if not context_mesh:
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
out_s = NamedSharding(context_mesh, P(*[None] * ndim))
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
return out_s
class ShapedArray(UnshapedArray):
@@ -1720,9 +1730,6 @@ class ShapedArray(UnshapedArray):
self.weak_type = weak_type
if config.sharding_in_types.value:
self.sharding = get_sharding(sharding, len(self.shape))
if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
raise ValueError(
f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
@@ -1796,14 +1803,6 @@ def _get_shape_sharding_str(shape, spec):
out.append(f"{s1}@{s2}")
return ','.join(out)
def _get_abstract_sharding(val):
from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error
if (config.sharding_in_types.value and hasattr(val, 'sharding') and
isinstance(val.sharding, NamedSharding)):
return NamedSharding(val.sharding.mesh.abstract_mesh,
val.sharding.spec._normalized_spec(val.ndim))
return None
def primal_dtype_to_tangent_dtype(primal_dtype):
if isinstance(primal_dtype, dtypes.ExtendedDType):

View File

@@ -230,9 +230,9 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
if not hasattr(x, 'shape')))) from err
if (config.sharding_in_types.value and
not all(x.sharding.spec[0] is None for x in xs_flat)):
not all(x.aval.sharding.spec[0] is None for x in xs_flat)):
raise ValueError('0th dimension of all xs should be replicated. Got '
f'{", ".join(str(x.sharding.spec) for x in xs_flat)}')
f'{", ".join(str(x.aval.sharding.spec) for x in xs_flat)}')
if length is not None:
try:

View File

@@ -586,7 +586,7 @@ def _convert_element_type(
if (config.sharding_in_types.value and sharding is None and
isinstance(operand, Array)):
sharding = operand.sharding
sharding = operand.aval.sharding
sharding = canonicalize_sharding(sharding, check_mesh_consistency=False) # type: ignore
@@ -1920,7 +1920,8 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
fill_value = _convert_element_type(fill_value, dtype, weak_type)
if (sharding is not None and not isinstance(sharding, PmapSharding) and
isinstance(fill_value, array.ArrayImpl)):
isinstance(fill_value, array.ArrayImpl) and
not config.sharding_in_types.value):
broadcast_shape = sharding.shard_shape(shape)
shard = broadcast(fill_value, broadcast_shape)
return array.make_array_from_callback(shape, sharding, lambda _: shard)
@@ -2137,7 +2138,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
if (config.sharding_in_types.value and sharding is None and
isinstance(x, Array)):
sharding = x.sharding
sharding = x.aval.sharding
else:
# If `x` has a sharding but no `_committed` attribute
# (in case of ShapeDtypeStruct), default it to True.
@@ -4496,7 +4497,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
broadcast_dimensions=broadcast_dimensions)
if config.sharding_in_types.value:
if sharding is not None:
assert sharding == aval_out.sharding
assert sharding == aval_out.sharding, (sharding, aval_out.sharding)
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
@@ -5656,7 +5657,7 @@ def _compute_argminmax(value_comparator, get_identity,
axis, = axes
indices = broadcasted_iota(
index_dtype, np.shape(operand), axis,
sharding=operand.sharding if config.sharding_in_types.value else None)
sharding=operand.aval.sharding if config.sharding_in_types.value else None)
res = reduce([operand, indices],
[get_identity(operand.dtype), np.array(0, index_dtype)],
_ArgMinMaxReducer(value_comparator),
@@ -6644,7 +6645,7 @@ _zeros: Callable = partial(full_like, fill_value=0)
def _zero(x):
if config.sharding_in_types.value:
return full_like(x, shape=(), fill_value=0,
sharding=x.sharding.with_spec(P())) # type: ignore
sharding=x.aval.sharding.with_spec(P())) # type: ignore
return full_like(x, shape=(), fill_value=0)
_ones: Callable = partial(full_like, fill_value=1)
@@ -6652,7 +6653,7 @@ _ones: Callable = partial(full_like, fill_value=1)
def _one(x):
if config.sharding_in_types.value:
return full_like(x, shape=(), fill_value=1,
sharding=x.sharding.with_spec(P()))
sharding=x.aval.sharding.with_spec(P()))
return full_like(x, shape=(), fill_value=1)
_twos: Callable = partial(full_like, fill_value=2)

View File

@@ -667,7 +667,7 @@ def _one_hot(x: Array, num_classes: int, *,
rhs_shape.insert(output_pos_axis, num_classes)
if config.sharding_in_types.value:
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
else:
rhs_sharding = None
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,

View File

@@ -5553,7 +5553,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
weak_type = dtype is None and dtypes.is_weakly_typed(object)
if (config.sharding_in_types.value and device is None and
isinstance(object, Array)):
sharding = object.sharding
sharding = object.aval.sharding
else:
sharding = canonicalize_device_to_sharding(device) # type: ignore

View File

@@ -2838,7 +2838,7 @@ def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
def decorator(*args, **kwargs):
with mesh_lib.set_abstract_mesh(new_mesh):
in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
a.sharding.spec, new_mesh), args)
a.aval.sharding.spec, new_mesh), args)
args = mesh_cast(args, in_specs)
out = fun(*args, **kwargs)
return mesh_cast(out, out_shardings)
@@ -2859,7 +2859,7 @@ def visible_axes(fun, *, axes: str | tuple[str, ...] | None = None,
args = mesh_cast(args, in_shardings)
out = fun(*args, **kwargs)
out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
return mesh_cast(out, out_specs)
return decorator

View File

@@ -4788,15 +4788,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
self.assertEqual(x.aval.sharding.spec, s.spec)
x = x * 2
self.assertEqual(x.sharding.spec, s.spec)
self.assertEqual(x.aval.sharding.spec, s.spec)
x = x * x
self.assertEqual(x.sharding.spec, s.spec)
self.assertEqual(x.aval.sharding.spec, s.spec)
return x
# Eager mode
out = f(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
f = jax.jit(f)
out = f(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
@@ -4832,9 +4838,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, y):
self.assertEqual(x.sharding.spec, s.spec)
self.assertEqual(x.aval.sharding.spec, s.spec)
out = x * y
self.assertEqual(out.sharding.spec, s.spec)
self.assertEqual(out.aval.sharding.spec, s.spec)
return out
out = f(arr1, arr2)
@@ -4876,16 +4882,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
@jax.jit
def f(x, y):
out = x @ y
self.assertEqual(out.sharding.spec, out_spec)
self.assertEqual(out.aval.sharding.spec, out_spec)
return out
out = f(arr1, arr2)
self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
f = jax.jit(f)
out = f(arr1, arr2)
self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
lowered = f.lower(arr1, arr2)
self.check_wsc_in_lowered(lowered.as_text())
@@ -4912,16 +4923,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
@jax.jit
def f(x, y):
out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
self.assertEqual(out.sharding.spec, P('x', None))
self.assertEqual(out.aval.sharding.spec, P('x', None))
return jnp.sum(out)
out = f(arr1, arr2)
self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T))
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
f = jax.jit(f)
out = f(arr1, arr2)
self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T))
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
out = jax.grad(f, argnums=(0, 1))(arr1, arr2)
self.assertEqual(out[0].sharding, arr1.sharding)
self.assertEqual(out[1].sharding, arr2.sharding)
@@ -4999,6 +5015,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), None)))
self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]')
@jtu.with_user_mesh((2, 1), ('x', 'y'))
def test_jnp_ones_mesh_context_eager(self, mesh):
s = NamedSharding(mesh, P('x', None))
out = jnp.ones((8, 2), dtype=jnp.int32, device=s)
self.assertEqual(out.sharding, s)
s = NamedSharding(mesh, P('x', 'y'))
out = jnp.ones((8, 2), dtype=jnp.int32, device=s)
self.assertEqual(out.sharding, s)
@parameterized.named_parameters(
('all', None, P('x', 'y'), P(), True),
('first', 0, P('x', 'y'), P('y'), True),
@@ -5014,9 +5040,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
self.assertEqual(x.aval.sharding.spec, s.spec)
y = jnp.sum(x, axis=axis)
self.assertEqual(y.sharding.spec, out_spec)
self.assertEqual(y.aval.sharding.spec, out_spec)
return y
out = f(arr)
@@ -5045,9 +5071,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
self.assertEqual(x.aval.sharding.spec, s.spec)
y = jnp.max(x, axis=axis)
self.assertEqual(y.sharding.spec, out_spec)
self.assertEqual(y.aval.sharding.spec, out_spec)
return y
out = f(arr)
@@ -5090,7 +5116,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = jnp.expand_dims(x, axis=axis)
self.assertEqual(y.sharding.spec, out_spec)
self.assertEqual(y.aval.sharding.spec, out_spec)
return y
out = f(arr)
@@ -5113,7 +5139,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x ** pow
self.assertEqual(y.sharding.spec, s.spec)
self.assertEqual(y.aval.sharding.spec, s.spec)
return y
out = f(arr)
@@ -5136,7 +5162,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
return x + y
with self.assertRaisesRegex(
ValueError, "For primitive add, context mesh.*aval mesh"):
ValueError, "For primitive.*context mesh.*aval mesh"):
f(arr1, arr2)
@jtu.with_user_mesh((2, 2), ('x', 'y'))
@@ -5148,7 +5174,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = lax.sin(x)
self.assertEqual(y.sharding.spec, s.spec)
self.assertEqual(y.aval.sharding.spec, s.spec)
return y
out = f(arr)
@@ -5168,7 +5194,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
assert x.dtype == jnp.int32
y = jnp.array(x, dtype=jnp.float32)
self.assertEqual(y.dtype, jnp.float32)
self.assertEqual(y.sharding.spec, s.spec)
self.assertEqual(y.aval.sharding.spec, s.spec)
return y
f(arr)
@@ -5182,7 +5208,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = jnp.transpose(x, (1, 2, 0))
self.assertEqual(y.sharding.spec, P('y', 'z', 'x'))
self.assertEqual(y.aval.sharding.spec, P('y', 'z', 'x'))
return y
out = f(arr)
@@ -5201,7 +5227,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = jax.nn.one_hot(x, 4)
self.assertEqual(y.sharding.spec, P('x', None))
self.assertEqual(y.aval.sharding.spec, P('x', None))
return y
out = f(arr)
@@ -5211,7 +5237,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def g(x):
x = x * 2
y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
self.assertEqual(y.sharding.spec, P('x', 'y'))
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
return x, y
_, out = g(arr)
@@ -5226,8 +5252,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, y):
out = jnp.einsum('xy,yz->xz', x, y,
out_sharding=NamedSharding(x.sharding.mesh, P('x', None)))
self.assertEqual(out.sharding.spec, P('x', None))
out_sharding=NamedSharding(x.aval.sharding.mesh, P('x', None)))
self.assertEqual(out.aval.sharding.spec, P('x', None))
return out
out = f(arr1, arr2)
@@ -5240,7 +5266,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def g(x, y):
out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
self.assertEqual(out.sharding.spec, P('x', None))
self.assertEqual(out.aval.sharding.spec, P('x', None))
return out
arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
@@ -5270,7 +5296,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def h(x, y):
spec = P('x', None, 'y', None)
out = jnp.einsum('btd,dhq->bhtq', x, y, out_sharding=spec)
self.assertEqual(out.sharding.spec, spec)
self.assertEqual(out.aval.sharding.spec, spec)
return out
arr1 = jax.device_put(np_inp.reshape(8, 4, 2),
@@ -5315,7 +5341,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def f(x, new_sharding):
y = lax.reshape(x, dst_shape, sharding=new_sharding)
y = y * 2
self.assertEqual(y.sharding.spec, dst_spec)
self.assertEqual(y.aval.sharding.spec, dst_spec)
return y
new_s = dst_spec if use_sharding_arg else None
@@ -5384,7 +5410,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def f(x):
y = lax.reshape(x, dst_shape)
y = y * 2
self.assertEqual(y.sharding.spec, dst_spec)
self.assertEqual(y.aval.sharding.spec, dst_spec)
return y
if error_msg:
@@ -5415,7 +5441,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(pred, on_true, on_false):
y = lax.select(pred, on_true, on_false)
self.assertEqual(y.sharding.spec, s.spec)
self.assertEqual(y.aval.sharding.spec, s.spec)
return y
out = f(arr1 == arr2, arr1, arr2)
@@ -5438,7 +5464,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = mesh_cast(x, NamedSharding(x.sharding.mesh, P('x', None)))
y = mesh_cast(x, NamedSharding(x.aval.sharding.mesh, P('x', None)))
return y
with self.assertRaisesRegex(
@@ -5481,18 +5507,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
def g(x, y):
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
self.assertTrue(x.aval.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.aval.sharding.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
return x * y
@jax.jit
def f(x, y):
z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec),
z = shard_map(g, mesh=mesh,
in_specs=(x.aval.sharding.spec, y.aval.sharding.spec),
out_specs=P('x', 'y'))(x, y)
self.assertEqual(z.sharding.spec, P('x', 'y'))
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
out = z * 2
self.assertEqual(out.sharding.spec, P('x', 'y'))
self.assertEqual(out.aval.sharding.spec, P('x', 'y'))
return out
out = f(arr, arr2)
@@ -5506,8 +5533,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
def g(x, y):
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
self.assertTrue(x.aval.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.aval.sharding.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
z = x @ allgatherd_y
@@ -5515,11 +5542,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, y):
z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec),
z = shard_map(g, mesh=mesh,
in_specs=(x.aval.sharding.spec, y.aval.sharding.spec),
out_specs=P('x', None))(x, y)
self.assertEqual(z.sharding.spec, P('x', None))
self.assertEqual(z.aval.sharding.spec, P('x', None))
out = z * 2
self.assertEqual(out.sharding.spec, P('x', None))
self.assertEqual(out.aval.sharding.spec, P('x', None))
return out
out = f(arr, arr2)
@@ -5534,7 +5562,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = lax.slice(x, (0, 0), (4, 3))
self.assertEqual(y.sharding.spec, P('x', None))
self.assertEqual(y.aval.sharding.spec, P('x', None))
return y
out = f(arr)
@@ -5565,7 +5593,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = lax.squeeze(x, (2,))
self.assertEqual(y.sharding.spec, P('x', None))
self.assertEqual(y.aval.sharding.spec, P('x', None))
return y
out = f(arr)
@@ -5591,7 +5619,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@partial(jax.jit, static_argnums=(1, 2))
def f(x, padding_config, spec):
y = lax.pad(x, 0., padding_config)
self.assertEqual(y.sharding.spec, spec)
self.assertEqual(y.aval.sharding.spec, spec)
return y
out = f(arr, ((2, 2, 0),), P('x'))
@@ -5639,7 +5667,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
else:
assert method == 'lax'
y = lax.concatenate([x, y], dimension=1)
self.assertEqual(y.sharding.spec, P('x', 'y'))
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
return y
out = f(arr1, arr2)
@@ -5677,14 +5705,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(carry, xs):
def g(carry, x):
self.assertEqual(carry.sharding.spec, P(None, 'x'))
self.assertEqual(x.sharding.spec, P('x', 'y'))
self.assertEqual(carry.aval.sharding.spec, P(None, 'x'))
self.assertEqual(x.aval.sharding.spec, P('x', 'y'))
y = carry @ x
self.assertEqual(y.sharding.spec, P(None, 'y'))
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
z = jax.nn.relu(y)
self.assertEqual(z.sharding.spec, P(None, 'y'))
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
a = z @ x.T
self.assertEqual(a.sharding.spec, P(None, 'x'))
self.assertEqual(a.aval.sharding.spec, P(None, 'x'))
return a, y
return jax.lax.scan(g, carry, xs)
@@ -5714,9 +5742,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
z = jnp.argmax(x, axis=0)
self.assertEqual(z.sharding.spec, P('y'))
self.assertEqual(z.aval.sharding.spec, P('y'))
a = jnp.argmin(x, axis=1)
self.assertEqual(a.sharding.spec, P('x'))
self.assertEqual(a.aval.sharding.spec, P('x'))
return z, a
out1, out2 = f(arr)
@@ -5734,11 +5762,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, x2):
y = x * 2
self.assertEqual(y.sharding.spec, P(None, None))
self.assertEqual(y.aval.sharding.spec, P(None, None))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, None))
self.assertEqual(z.aval.sharding.spec, P(None, None))
a = z @ x2
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
return a
out = f(arr, arr.T)
@@ -5819,13 +5847,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
y = x * 2
with use_hidden_axes('x', 'y'):
y = mesh_cast(y, P(None, None))
self.assertEqual(y.sharding.spec, P(None, None))
self.assertEqual(y.aval.sharding.spec, P(None, None))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, None))
self.assertEqual(z.aval.sharding.spec, P(None, None))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
a = mesh_cast(a, P('x', None))
self.assertEqual(a.sharding.spec, P('x', None))
self.assertEqual(a.aval.sharding.spec, P('x', None))
return a
out = f(arr)
@@ -5847,13 +5875,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
y = x * 2
with use_visible_axes('x', 'y'):
y = mesh_cast(y, P(None, 'y'))
self.assertEqual(y.sharding.spec, P(None, 'y'))
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, 'y'))
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
a = mesh_cast(a, P(None, None))
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
return a
out = f(arr)
@@ -5873,13 +5901,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
y = x * 2
with use_hidden_axes('x'):
y = mesh_cast(y, P(None, 'y'))
self.assertEqual(y.sharding.spec, P(None, 'y'))
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, 'y'))
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
a = mesh_cast(a, P('x', None))
self.assertEqual(a.sharding.spec, P('x', None))
self.assertEqual(a.aval.sharding.spec, P('x', None))
return a
out = f(arr)
@@ -5915,8 +5943,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@partial(jax.jit, static_argnums=(1, 2))
def f(x, sizes=(4, 4), axis=0):
ys = lax.split(x, sizes, axis=axis)
self.assertEqual(ys[0].sharding.spec, P('x', 'y'))
self.assertEqual(ys[1].sharding.spec, P('x', 'y'))
self.assertEqual(ys[0].aval.sharding.spec, P('x', 'y'))
self.assertEqual(ys[1].aval.sharding.spec, P('x', 'y'))
return ys
f(arr)
@@ -6010,7 +6038,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
self.assertEqual(y.sharding.spec, P('x', None, None))
self.assertEqual(y.aval.sharding.spec, P('x', None, None))
return y
out = f(arr)
@@ -6032,18 +6060,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@partial(hidden_axes, axes='x', out_shardings=P('x', None))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, 'y'))
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
return a
@jax.jit
def g(x):
y = x * 2
a = h(y)
self.assertEqual(a.sharding.spec, P('x', None))
self.assertEqual(a.aval.sharding.spec, P('x', None))
return a
out = g(arr)
@@ -6063,18 +6091,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
# No axes specified means full visible mode.
@partial(visible_axes, in_shardings=P('x', 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P('x', 'y'))
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P('x', 'y'))
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', None))
self.assertEqual(a.sharding.spec, P('x', None))
self.assertEqual(a.aval.sharding.spec, P('x', None))
return a
@jax.jit
def f(x):
y = x * 2
a = h(y)
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
return a
out = f(arr)
@@ -6093,18 +6121,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@partial(visible_axes, axes='y', in_shardings=P('x', 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P('x', 'y'))
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P('x', 'y'))
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', 'y'))
self.assertEqual(a.sharding.spec, P('x', 'y'))
self.assertEqual(a.aval.sharding.spec, P('x', 'y'))
return a
@jax.jit
def f(x):
y = x * 2
a = h(y)
self.assertEqual(a.sharding.spec, P('x', None))
self.assertEqual(a.aval.sharding.spec, P('x', None))
return a
out = f(arr)
@@ -6119,18 +6147,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@partial(visible_axes, axes='y', in_shardings=P(None, 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, 'y'))
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P(None, 'y'))
self.assertEqual(a.sharding.spec, P(None, 'y'))
self.assertEqual(a.aval.sharding.spec, P(None, 'y'))
return a
@jax.jit
def f(x):
y = x * 2
a = h(y)
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
return a
out = f(arr)
@@ -6147,7 +6175,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def f(embed_vd, token_bt):
out = embed_vd.at[token_bt].get(out_sharding=P('x', None, None))
self.assertEqual(out.shape, (8, 4, 16))
self.assertEqual(out.sharding.spec, P('x', None, None))
self.assertEqual(out.aval.sharding.spec, P('x', None, None))
return out
out = f(embed, tok)
@@ -6213,11 +6241,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
self.assertEqual(y.sharding.spec, P(None, None))
self.assertEqual(y.aval.sharding.spec, P(None, None))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(None, None))
self.assertEqual(z.aval.sharding.spec, P(None, None))
a = z @ z.T
self.assertEqual(a.sharding.spec, P(None, None))
self.assertEqual(a.aval.sharding.spec, P(None, None))
return a
hf = hidden_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y'))
@@ -6234,9 +6262,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
self.assertEqual(y.sharding.spec, P('x', 'y'))
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P('x', 'y'))
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
return z
hf = visible_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y'))