mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Don't allow users to query tracer.sharding
even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode. PiperOrigin-RevId: 717638986
This commit is contained in:
parent
7f19b345fb
commit
d50d1e2c40
@ -67,8 +67,7 @@ from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
|
||||
NamedSharding)
|
||||
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
|
||||
from jax._src.layout import Layout, AutoLayout
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
@ -2564,11 +2563,7 @@ def _sds_aval_mapping(x):
|
||||
aval = ShapedArray(
|
||||
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
|
||||
weak_type=x.weak_type)
|
||||
if config.sharding_in_types.value and isinstance(x.sharding, NamedSharding):
|
||||
return aval.update(sharding=NamedSharding(
|
||||
x.sharding.mesh.abstract_mesh,
|
||||
x.sharding.spec._normalized_spec(x.ndim)))
|
||||
return aval
|
||||
return core.update_aval_with_sharding(aval, x.sharding)
|
||||
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension as xe
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding, NamedSharding,
|
||||
PmapSharding, SingleDeviceSharding,
|
||||
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
|
||||
from jax._src.typing import ArrayLike, DLDeviceType
|
||||
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
|
||||
@ -753,7 +753,8 @@ def make_array_from_callback(
|
||||
first_value = per_device_values[0]
|
||||
expected_dtype = first_value.dtype
|
||||
expected_shape = sharding.shard_shape(shape)
|
||||
aval = core.ShapedArray(shape, expected_dtype)
|
||||
aval = core.update_aval_with_sharding(
|
||||
core.ShapedArray(shape, expected_dtype), sharding)
|
||||
_validate_shape_and_dtype_for_per_device_arrays(
|
||||
per_device_values,
|
||||
expected_shape=expected_shape,
|
||||
@ -1017,7 +1018,8 @@ def make_array_from_single_device_arrays(
|
||||
raise ValueError(
|
||||
"jax.make_array_from_single_device_arrays requires a list of concrete"
|
||||
f" arrays as input. got types {set(map(type, arrays))}")
|
||||
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
|
||||
aval = core.update_aval_with_sharding(
|
||||
core.ShapedArray(shape, arrays[0].dtype, weak_type=False), sharding)
|
||||
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
||||
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
|
||||
committed=True)
|
||||
@ -1028,13 +1030,7 @@ def make_array_from_single_device_arrays(
|
||||
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
|
||||
|
||||
def _get_aval_array(self):
|
||||
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
|
||||
return self.aval.update(sharding=NamedSharding(
|
||||
self.sharding.mesh.abstract_mesh,
|
||||
self.sharding.spec._normalized_spec(self.ndim)))
|
||||
else:
|
||||
return self.aval
|
||||
|
||||
return core.update_aval_with_sharding(self.aval, self.sharding)
|
||||
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
|
||||
|
||||
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
||||
@ -1179,6 +1175,7 @@ pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
|
||||
|
||||
|
||||
def _array_global_result_handler(global_aval, out_sharding, committed):
|
||||
global_aval = core.update_aval_with_sharding(global_aval, out_sharding)
|
||||
if global_aval.dtype == dtypes.float0:
|
||||
return lambda _: np.zeros(global_aval.shape, dtypes.float0)
|
||||
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
|
||||
|
@ -816,6 +816,12 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
# if the aval property raises an AttributeError, gets caught here
|
||||
assert not config.enable_checks.value or name != "aval"
|
||||
|
||||
if name == 'sharding':
|
||||
raise AttributeError(
|
||||
self,
|
||||
f"The 'sharding' attribute is not available on {self._error_repr()}."
|
||||
f"{self._origin_msg()}")
|
||||
|
||||
try:
|
||||
attr = getattr(self.aval, name)
|
||||
except AttributeError as err:
|
||||
@ -1421,6 +1427,13 @@ def check_valid_jaxtype(x):
|
||||
raise TypeError(
|
||||
f"Value {x!r} of type {type(x)} is not a valid JAX type")
|
||||
|
||||
def update_aval_with_sharding(aval, sharding):
|
||||
from jax._src.sharding_impls import NamedSharding # type: ignore
|
||||
if config.sharding_in_types.value and isinstance(sharding, NamedSharding):
|
||||
aval = aval.update(sharding=NamedSharding(
|
||||
sharding.mesh.abstract_mesh, sharding.spec._normalized_spec(aval.ndim)))
|
||||
return aval
|
||||
|
||||
|
||||
# We have three flavors of abstractification APIs here which each used to have
|
||||
# their own separate implementation. Now they're effectively the same, with the
|
||||
@ -1433,8 +1446,6 @@ def check_valid_jaxtype(x):
|
||||
# TODO(jakevdp): can these be unified further?
|
||||
|
||||
def shaped_abstractify(x):
|
||||
from jax._src.sharding_impls import NamedSharding # type: ignore
|
||||
|
||||
typ = type(x)
|
||||
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
|
||||
return aval_fn(x)
|
||||
@ -1448,12 +1459,7 @@ def shaped_abstractify(x):
|
||||
if hasattr(x, 'dtype'):
|
||||
aval = ShapedArray(np.shape(x), x.dtype,
|
||||
weak_type=getattr(x, 'weak_type', False))
|
||||
if (config.sharding_in_types.value and hasattr(x, 'sharding') and
|
||||
isinstance(x.sharding, NamedSharding)):
|
||||
return aval.update(sharding=NamedSharding(
|
||||
x.sharding.mesh.abstract_mesh,
|
||||
x.sharding.spec._normalized_spec(aval.ndim)))
|
||||
return aval
|
||||
return update_aval_with_sharding(aval, getattr(x, 'sharding', None))
|
||||
raise TypeError(
|
||||
f"Cannot interpret value of type {typ} as an abstract array; it "
|
||||
"does not have a dtype attribute")
|
||||
@ -1701,13 +1707,17 @@ def get_sharding(sharding, ndim):
|
||||
raise ValueError(
|
||||
"Length of sharding.spec must be equal to aval's ndim. Got"
|
||||
f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
|
||||
return _maybe_modify_sharding(sharding)
|
||||
|
||||
context_mesh = mesh_lib.get_abstract_mesh()
|
||||
if not context_mesh:
|
||||
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
|
||||
assert sharding is None
|
||||
return NamedSharding(context_mesh, P(*[None] * ndim))
|
||||
out_s = _maybe_modify_sharding(sharding)
|
||||
else:
|
||||
context_mesh = mesh_lib.get_abstract_mesh()
|
||||
if not context_mesh:
|
||||
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
|
||||
assert sharding is None
|
||||
out_s = NamedSharding(context_mesh, P(*[None] * ndim))
|
||||
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
|
||||
raise ValueError("Mesh of an aval must be an AbstractMesh. "
|
||||
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
|
||||
return out_s
|
||||
|
||||
|
||||
class ShapedArray(UnshapedArray):
|
||||
@ -1720,9 +1730,6 @@ class ShapedArray(UnshapedArray):
|
||||
self.weak_type = weak_type
|
||||
if config.sharding_in_types.value:
|
||||
self.sharding = get_sharding(sharding, len(self.shape))
|
||||
if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
|
||||
raise ValueError(
|
||||
f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")
|
||||
|
||||
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
|
||||
if shape is None:
|
||||
@ -1796,14 +1803,6 @@ def _get_shape_sharding_str(shape, spec):
|
||||
out.append(f"{s1}@{s2}")
|
||||
return ','.join(out)
|
||||
|
||||
def _get_abstract_sharding(val):
|
||||
from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error
|
||||
|
||||
if (config.sharding_in_types.value and hasattr(val, 'sharding') and
|
||||
isinstance(val.sharding, NamedSharding)):
|
||||
return NamedSharding(val.sharding.mesh.abstract_mesh,
|
||||
val.sharding.spec._normalized_spec(val.ndim))
|
||||
return None
|
||||
|
||||
def primal_dtype_to_tangent_dtype(primal_dtype):
|
||||
if isinstance(primal_dtype, dtypes.ExtendedDType):
|
||||
|
@ -230,9 +230,9 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
if not hasattr(x, 'shape')))) from err
|
||||
|
||||
if (config.sharding_in_types.value and
|
||||
not all(x.sharding.spec[0] is None for x in xs_flat)):
|
||||
not all(x.aval.sharding.spec[0] is None for x in xs_flat)):
|
||||
raise ValueError('0th dimension of all xs should be replicated. Got '
|
||||
f'{", ".join(str(x.sharding.spec) for x in xs_flat)}')
|
||||
f'{", ".join(str(x.aval.sharding.spec) for x in xs_flat)}')
|
||||
|
||||
if length is not None:
|
||||
try:
|
||||
|
@ -586,7 +586,7 @@ def _convert_element_type(
|
||||
|
||||
if (config.sharding_in_types.value and sharding is None and
|
||||
isinstance(operand, Array)):
|
||||
sharding = operand.sharding
|
||||
sharding = operand.aval.sharding
|
||||
|
||||
sharding = canonicalize_sharding(sharding, check_mesh_consistency=False) # type: ignore
|
||||
|
||||
@ -1920,7 +1920,8 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
|
||||
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
|
||||
fill_value = _convert_element_type(fill_value, dtype, weak_type)
|
||||
if (sharding is not None and not isinstance(sharding, PmapSharding) and
|
||||
isinstance(fill_value, array.ArrayImpl)):
|
||||
isinstance(fill_value, array.ArrayImpl) and
|
||||
not config.sharding_in_types.value):
|
||||
broadcast_shape = sharding.shard_shape(shape)
|
||||
shard = broadcast(fill_value, broadcast_shape)
|
||||
return array.make_array_from_callback(shape, sharding, lambda _: shard)
|
||||
@ -2137,7 +2138,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
|
||||
|
||||
if (config.sharding_in_types.value and sharding is None and
|
||||
isinstance(x, Array)):
|
||||
sharding = x.sharding
|
||||
sharding = x.aval.sharding
|
||||
else:
|
||||
# If `x` has a sharding but no `_committed` attribute
|
||||
# (in case of ShapeDtypeStruct), default it to True.
|
||||
@ -4496,7 +4497,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
|
||||
broadcast_dimensions=broadcast_dimensions)
|
||||
if config.sharding_in_types.value:
|
||||
if sharding is not None:
|
||||
assert sharding == aval_out.sharding
|
||||
assert sharding == aval_out.sharding, (sharding, aval_out.sharding)
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
@ -5656,7 +5657,7 @@ def _compute_argminmax(value_comparator, get_identity,
|
||||
axis, = axes
|
||||
indices = broadcasted_iota(
|
||||
index_dtype, np.shape(operand), axis,
|
||||
sharding=operand.sharding if config.sharding_in_types.value else None)
|
||||
sharding=operand.aval.sharding if config.sharding_in_types.value else None)
|
||||
res = reduce([operand, indices],
|
||||
[get_identity(operand.dtype), np.array(0, index_dtype)],
|
||||
_ArgMinMaxReducer(value_comparator),
|
||||
@ -6644,7 +6645,7 @@ _zeros: Callable = partial(full_like, fill_value=0)
|
||||
def _zero(x):
|
||||
if config.sharding_in_types.value:
|
||||
return full_like(x, shape=(), fill_value=0,
|
||||
sharding=x.sharding.with_spec(P())) # type: ignore
|
||||
sharding=x.aval.sharding.with_spec(P())) # type: ignore
|
||||
return full_like(x, shape=(), fill_value=0)
|
||||
|
||||
_ones: Callable = partial(full_like, fill_value=1)
|
||||
@ -6652,7 +6653,7 @@ _ones: Callable = partial(full_like, fill_value=1)
|
||||
def _one(x):
|
||||
if config.sharding_in_types.value:
|
||||
return full_like(x, shape=(), fill_value=1,
|
||||
sharding=x.sharding.with_spec(P()))
|
||||
sharding=x.aval.sharding.with_spec(P()))
|
||||
return full_like(x, shape=(), fill_value=1)
|
||||
|
||||
_twos: Callable = partial(full_like, fill_value=2)
|
||||
|
@ -667,7 +667,7 @@ def _one_hot(x: Array, num_classes: int, *,
|
||||
rhs_shape.insert(output_pos_axis, num_classes)
|
||||
if config.sharding_in_types.value:
|
||||
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
|
||||
rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
|
||||
rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
|
||||
else:
|
||||
rhs_sharding = None
|
||||
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
|
||||
|
@ -5553,7 +5553,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(object)
|
||||
if (config.sharding_in_types.value and device is None and
|
||||
isinstance(object, Array)):
|
||||
sharding = object.sharding
|
||||
sharding = object.aval.sharding
|
||||
else:
|
||||
sharding = canonicalize_device_to_sharding(device) # type: ignore
|
||||
|
||||
|
@ -2838,7 +2838,7 @@ def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
def decorator(*args, **kwargs):
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
|
||||
a.sharding.spec, new_mesh), args)
|
||||
a.aval.sharding.spec, new_mesh), args)
|
||||
args = mesh_cast(args, in_specs)
|
||||
out = fun(*args, **kwargs)
|
||||
return mesh_cast(out, out_shardings)
|
||||
@ -2859,7 +2859,7 @@ def visible_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
args = mesh_cast(args, in_shardings)
|
||||
out = fun(*args, **kwargs)
|
||||
out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
|
||||
o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
|
||||
o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
|
||||
return mesh_cast(out, out_specs)
|
||||
return decorator
|
||||
|
||||
|
@ -4788,15 +4788,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
x = x * 2
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
x = x * x
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
return x
|
||||
|
||||
# Eager mode
|
||||
out = f(arr)
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
|
||||
|
||||
f = jax.jit(f)
|
||||
|
||||
out = f(arr)
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
|
||||
@ -4832,9 +4838,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
out = x * y
|
||||
self.assertEqual(out.sharding.spec, s.spec)
|
||||
self.assertEqual(out.aval.sharding.spec, s.spec)
|
||||
return out
|
||||
|
||||
out = f(arr1, arr2)
|
||||
@ -4876,16 +4882,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
|
||||
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
out = x @ y
|
||||
self.assertEqual(out.sharding.spec, out_spec)
|
||||
self.assertEqual(out.aval.sharding.spec, out_spec)
|
||||
return out
|
||||
|
||||
out = f(arr1, arr2)
|
||||
self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
|
||||
|
||||
f = jax.jit(f)
|
||||
|
||||
out = f(arr1, arr2)
|
||||
self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
|
||||
|
||||
lowered = f.lower(arr1, arr2)
|
||||
self.check_wsc_in_lowered(lowered.as_text())
|
||||
|
||||
@ -4912,16 +4923,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
|
||||
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
|
||||
self.assertEqual(out.sharding.spec, P('x', None))
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', None))
|
||||
return jnp.sum(out)
|
||||
|
||||
out = f(arr1, arr2)
|
||||
self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
|
||||
|
||||
f = jax.jit(f)
|
||||
|
||||
out = f(arr1, arr2)
|
||||
self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
|
||||
|
||||
out = jax.grad(f, argnums=(0, 1))(arr1, arr2)
|
||||
self.assertEqual(out[0].sharding, arr1.sharding)
|
||||
self.assertEqual(out[1].sharding, arr2.sharding)
|
||||
@ -4999,6 +5015,16 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), None)))
|
||||
self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]')
|
||||
|
||||
@jtu.with_user_mesh((2, 1), ('x', 'y'))
|
||||
def test_jnp_ones_mesh_context_eager(self, mesh):
|
||||
s = NamedSharding(mesh, P('x', None))
|
||||
out = jnp.ones((8, 2), dtype=jnp.int32, device=s)
|
||||
self.assertEqual(out.sharding, s)
|
||||
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
out = jnp.ones((8, 2), dtype=jnp.int32, device=s)
|
||||
self.assertEqual(out.sharding, s)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('all', None, P('x', 'y'), P(), True),
|
||||
('first', 0, P('x', 'y'), P('y'), True),
|
||||
@ -5014,9 +5040,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
y = jnp.sum(x, axis=axis)
|
||||
self.assertEqual(y.sharding.spec, out_spec)
|
||||
self.assertEqual(y.aval.sharding.spec, out_spec)
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5045,9 +5071,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
y = jnp.max(x, axis=axis)
|
||||
self.assertEqual(y.sharding.spec, out_spec)
|
||||
self.assertEqual(y.aval.sharding.spec, out_spec)
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5090,7 +5116,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = jnp.expand_dims(x, axis=axis)
|
||||
self.assertEqual(y.sharding.spec, out_spec)
|
||||
self.assertEqual(y.aval.sharding.spec, out_spec)
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5113,7 +5139,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x ** pow
|
||||
self.assertEqual(y.sharding.spec, s.spec)
|
||||
self.assertEqual(y.aval.sharding.spec, s.spec)
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5136,7 +5162,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
return x + y
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "For primitive add, context mesh.*aval mesh"):
|
||||
ValueError, "For primitive.*context mesh.*aval mesh"):
|
||||
f(arr1, arr2)
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
@ -5148,7 +5174,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = lax.sin(x)
|
||||
self.assertEqual(y.sharding.spec, s.spec)
|
||||
self.assertEqual(y.aval.sharding.spec, s.spec)
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5168,7 +5194,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
assert x.dtype == jnp.int32
|
||||
y = jnp.array(x, dtype=jnp.float32)
|
||||
self.assertEqual(y.dtype, jnp.float32)
|
||||
self.assertEqual(y.sharding.spec, s.spec)
|
||||
self.assertEqual(y.aval.sharding.spec, s.spec)
|
||||
return y
|
||||
|
||||
f(arr)
|
||||
@ -5182,7 +5208,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = jnp.transpose(x, (1, 2, 0))
|
||||
self.assertEqual(y.sharding.spec, P('y', 'z', 'x'))
|
||||
self.assertEqual(y.aval.sharding.spec, P('y', 'z', 'x'))
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5201,7 +5227,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = jax.nn.one_hot(x, 4)
|
||||
self.assertEqual(y.sharding.spec, P('x', None))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', None))
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5211,7 +5237,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def g(x):
|
||||
x = x * 2
|
||||
y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
|
||||
self.assertEqual(y.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
|
||||
return x, y
|
||||
|
||||
_, out = g(arr)
|
||||
@ -5226,8 +5252,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
out = jnp.einsum('xy,yz->xz', x, y,
|
||||
out_sharding=NamedSharding(x.sharding.mesh, P('x', None)))
|
||||
self.assertEqual(out.sharding.spec, P('x', None))
|
||||
out_sharding=NamedSharding(x.aval.sharding.mesh, P('x', None)))
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', None))
|
||||
return out
|
||||
|
||||
out = f(arr1, arr2)
|
||||
@ -5240,7 +5266,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def g(x, y):
|
||||
out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
|
||||
self.assertEqual(out.sharding.spec, P('x', None))
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', None))
|
||||
return out
|
||||
|
||||
arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
@ -5270,7 +5296,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def h(x, y):
|
||||
spec = P('x', None, 'y', None)
|
||||
out = jnp.einsum('btd,dhq->bhtq', x, y, out_sharding=spec)
|
||||
self.assertEqual(out.sharding.spec, spec)
|
||||
self.assertEqual(out.aval.sharding.spec, spec)
|
||||
return out
|
||||
|
||||
arr1 = jax.device_put(np_inp.reshape(8, 4, 2),
|
||||
@ -5315,7 +5341,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def f(x, new_sharding):
|
||||
y = lax.reshape(x, dst_shape, sharding=new_sharding)
|
||||
y = y * 2
|
||||
self.assertEqual(y.sharding.spec, dst_spec)
|
||||
self.assertEqual(y.aval.sharding.spec, dst_spec)
|
||||
return y
|
||||
|
||||
new_s = dst_spec if use_sharding_arg else None
|
||||
@ -5384,7 +5410,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
y = lax.reshape(x, dst_shape)
|
||||
y = y * 2
|
||||
self.assertEqual(y.sharding.spec, dst_spec)
|
||||
self.assertEqual(y.aval.sharding.spec, dst_spec)
|
||||
return y
|
||||
|
||||
if error_msg:
|
||||
@ -5415,7 +5441,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(pred, on_true, on_false):
|
||||
y = lax.select(pred, on_true, on_false)
|
||||
self.assertEqual(y.sharding.spec, s.spec)
|
||||
self.assertEqual(y.aval.sharding.spec, s.spec)
|
||||
return y
|
||||
|
||||
out = f(arr1 == arr2, arr1, arr2)
|
||||
@ -5438,7 +5464,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = mesh_cast(x, NamedSharding(x.sharding.mesh, P('x', None)))
|
||||
y = mesh_cast(x, NamedSharding(x.aval.sharding.mesh, P('x', None)))
|
||||
return y
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
@ -5481,18 +5507,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
def g(x, y):
|
||||
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(x.aval.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.aval.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
|
||||
return x * y
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec),
|
||||
z = shard_map(g, mesh=mesh,
|
||||
in_specs=(x.aval.sharding.spec, y.aval.sharding.spec),
|
||||
out_specs=P('x', 'y'))(x, y)
|
||||
self.assertEqual(z.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
|
||||
out = z * 2
|
||||
self.assertEqual(out.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', 'y'))
|
||||
return out
|
||||
|
||||
out = f(arr, arr2)
|
||||
@ -5506,8 +5533,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
|
||||
|
||||
def g(x, y):
|
||||
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(x.aval.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.aval.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
|
||||
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
|
||||
z = x @ allgatherd_y
|
||||
@ -5515,11 +5542,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec),
|
||||
z = shard_map(g, mesh=mesh,
|
||||
in_specs=(x.aval.sharding.spec, y.aval.sharding.spec),
|
||||
out_specs=P('x', None))(x, y)
|
||||
self.assertEqual(z.sharding.spec, P('x', None))
|
||||
self.assertEqual(z.aval.sharding.spec, P('x', None))
|
||||
out = z * 2
|
||||
self.assertEqual(out.sharding.spec, P('x', None))
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', None))
|
||||
return out
|
||||
|
||||
out = f(arr, arr2)
|
||||
@ -5534,7 +5562,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = lax.slice(x, (0, 0), (4, 3))
|
||||
self.assertEqual(y.sharding.spec, P('x', None))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', None))
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5565,7 +5593,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = lax.squeeze(x, (2,))
|
||||
self.assertEqual(y.sharding.spec, P('x', None))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', None))
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -5591,7 +5619,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@partial(jax.jit, static_argnums=(1, 2))
|
||||
def f(x, padding_config, spec):
|
||||
y = lax.pad(x, 0., padding_config)
|
||||
self.assertEqual(y.sharding.spec, spec)
|
||||
self.assertEqual(y.aval.sharding.spec, spec)
|
||||
return y
|
||||
|
||||
out = f(arr, ((2, 2, 0),), P('x'))
|
||||
@ -5639,7 +5667,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
else:
|
||||
assert method == 'lax'
|
||||
y = lax.concatenate([x, y], dimension=1)
|
||||
self.assertEqual(y.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
|
||||
return y
|
||||
|
||||
out = f(arr1, arr2)
|
||||
@ -5677,14 +5705,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(carry, xs):
|
||||
def g(carry, x):
|
||||
self.assertEqual(carry.sharding.spec, P(None, 'x'))
|
||||
self.assertEqual(x.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(carry.aval.sharding.spec, P(None, 'x'))
|
||||
self.assertEqual(x.aval.sharding.spec, P('x', 'y'))
|
||||
y = carry @ x
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
|
||||
z = jax.nn.relu(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
|
||||
a = z @ x.T
|
||||
self.assertEqual(a.sharding.spec, P(None, 'x'))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, 'x'))
|
||||
return a, y
|
||||
return jax.lax.scan(g, carry, xs)
|
||||
|
||||
@ -5714,9 +5742,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
z = jnp.argmax(x, axis=0)
|
||||
self.assertEqual(z.sharding.spec, P('y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P('y'))
|
||||
a = jnp.argmin(x, axis=1)
|
||||
self.assertEqual(a.sharding.spec, P('x'))
|
||||
self.assertEqual(a.aval.sharding.spec, P('x'))
|
||||
return z, a
|
||||
|
||||
out1, out2 = f(arr)
|
||||
@ -5734,11 +5762,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x, x2):
|
||||
y = x * 2
|
||||
self.assertEqual(y.sharding.spec, P(None, None))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, None))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, None))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, None))
|
||||
a = z @ x2
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
return a
|
||||
|
||||
out = f(arr, arr.T)
|
||||
@ -5819,13 +5847,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
y = x * 2
|
||||
with use_hidden_axes('x', 'y'):
|
||||
y = mesh_cast(y, P(None, None))
|
||||
self.assertEqual(y.sharding.spec, P(None, None))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, None))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, None))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, None))
|
||||
a = z @ z.T
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
a = mesh_cast(a, P('x', None))
|
||||
self.assertEqual(a.sharding.spec, P('x', None))
|
||||
self.assertEqual(a.aval.sharding.spec, P('x', None))
|
||||
return a
|
||||
|
||||
out = f(arr)
|
||||
@ -5847,13 +5875,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
y = x * 2
|
||||
with use_visible_axes('x', 'y'):
|
||||
y = mesh_cast(y, P(None, 'y'))
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
|
||||
a = z @ z.T
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
a = mesh_cast(a, P(None, None))
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
return a
|
||||
|
||||
out = f(arr)
|
||||
@ -5873,13 +5901,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
y = x * 2
|
||||
with use_hidden_axes('x'):
|
||||
y = mesh_cast(y, P(None, 'y'))
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
|
||||
a = z @ z.T
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
a = mesh_cast(a, P('x', None))
|
||||
self.assertEqual(a.sharding.spec, P('x', None))
|
||||
self.assertEqual(a.aval.sharding.spec, P('x', None))
|
||||
return a
|
||||
|
||||
out = f(arr)
|
||||
@ -5915,8 +5943,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@partial(jax.jit, static_argnums=(1, 2))
|
||||
def f(x, sizes=(4, 4), axis=0):
|
||||
ys = lax.split(x, sizes, axis=axis)
|
||||
self.assertEqual(ys[0].sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(ys[1].sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(ys[0].aval.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(ys[1].aval.sharding.spec, P('x', 'y'))
|
||||
return ys
|
||||
|
||||
f(arr)
|
||||
@ -6010,7 +6038,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
self.assertEqual(y.sharding.spec, P('x', None, None))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', None, None))
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
@ -6032,18 +6060,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@partial(hidden_axes, axes='x', out_shardings=P('x', None))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
|
||||
a = z @ z.T
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
return a
|
||||
|
||||
@jax.jit
|
||||
def g(x):
|
||||
y = x * 2
|
||||
a = h(y)
|
||||
self.assertEqual(a.sharding.spec, P('x', None))
|
||||
self.assertEqual(a.aval.sharding.spec, P('x', None))
|
||||
return a
|
||||
|
||||
out = g(arr)
|
||||
@ -6063,18 +6091,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
# No axes specified means full visible mode.
|
||||
@partial(visible_axes, in_shardings=P('x', 'y'))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
|
||||
a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', None))
|
||||
self.assertEqual(a.sharding.spec, P('x', None))
|
||||
self.assertEqual(a.aval.sharding.spec, P('x', None))
|
||||
return a
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
a = h(y)
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
return a
|
||||
|
||||
out = f(arr)
|
||||
@ -6093,18 +6121,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@partial(visible_axes, axes='y', in_shardings=P('x', 'y'))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
|
||||
a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', 'y'))
|
||||
self.assertEqual(a.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(a.aval.sharding.spec, P('x', 'y'))
|
||||
return a
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
a = h(y)
|
||||
self.assertEqual(a.sharding.spec, P('x', None))
|
||||
self.assertEqual(a.aval.sharding.spec, P('x', None))
|
||||
return a
|
||||
|
||||
out = f(arr)
|
||||
@ -6119,18 +6147,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@partial(visible_axes, axes='y', in_shardings=P(None, 'y'))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, 'y'))
|
||||
a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P(None, 'y'))
|
||||
self.assertEqual(a.sharding.spec, P(None, 'y'))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, 'y'))
|
||||
return a
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
a = h(y)
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
return a
|
||||
|
||||
out = f(arr)
|
||||
@ -6147,7 +6175,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def f(embed_vd, token_bt):
|
||||
out = embed_vd.at[token_bt].get(out_sharding=P('x', None, None))
|
||||
self.assertEqual(out.shape, (8, 4, 16))
|
||||
self.assertEqual(out.sharding.spec, P('x', None, None))
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', None, None))
|
||||
return out
|
||||
|
||||
out = f(embed, tok)
|
||||
@ -6213,11 +6241,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
self.assertEqual(y.sharding.spec, P(None, None))
|
||||
self.assertEqual(y.aval.sharding.spec, P(None, None))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(None, None))
|
||||
self.assertEqual(z.aval.sharding.spec, P(None, None))
|
||||
a = z @ z.T
|
||||
self.assertEqual(a.sharding.spec, P(None, None))
|
||||
self.assertEqual(a.aval.sharding.spec, P(None, None))
|
||||
return a
|
||||
|
||||
hf = hidden_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y'))
|
||||
@ -6234,9 +6262,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
self.assertEqual(y.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(y.aval.sharding.spec, P('x', 'y'))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P('x', 'y'))
|
||||
self.assertEqual(z.aval.sharding.spec, P('x', 'y'))
|
||||
return z
|
||||
|
||||
hf = visible_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user