Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
[sharding_in_types] Functions like `einsum`, `reshape`, `broadcast_in_dim`, `broadcasted_iota`, `convert_element_type`, and `sharding_cast` that take an `out_sharding` argument in their signature should accept a `PartitionSpec` as input, not just a `NamedSharding`.

If a `PartitionSpec` is passed, the mesh is read from the context. The primitives themselves still take only `NamedSharding`; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`. For the functions above, we also raise an error if the `PartitionSpec` contains mesh axis names whose type is Auto or Collective.

PiperOrigin-RevId: 713352542
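For illustration (not part of the change itself), a minimal sketch of the call-site difference this enables, mirroring the tests further below. It assumes the experimental `sharding_in_types` mode is enabled and at least two devices are available; the mesh construction is illustrative only, and the jitted function is merely defined here, not executed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative setup: a 2x1 mesh over the first two devices.
mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))
y = jax.device_put(np.arange(4.).reshape(2, 2), NamedSharding(mesh, P('y', None)))

@jax.jit
def f(a, b):
  # Previously the output sharding had to be spelled as a NamedSharding built
  # from the mesh; with this change a bare PartitionSpec is accepted and the
  # mesh is read from the surrounding abstract-mesh context.
  return jnp.einsum('xy,yz->xz', a, b, out_type=P('x', None))
```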
commit 3848f0d2ac
parent c1a60c676a
@@ -866,6 +866,7 @@ pytype_strict_library(
         ":partition_spec",
         ":sharding",
         ":sharding_specs",
+        ":source_info_util",
         ":tree_util",
         ":util",
         ":xla_bridge",
@@ -1696,6 +1696,9 @@ class ShapedArray(UnshapedArray):
     self.weak_type = weak_type
     if config.sharding_in_types.value:
       self.sharding = get_sharding(sharding, len(self.shape))
+      if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
+        raise ValueError(
+            f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")

   def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
     if shape is None:
@@ -64,7 +64,7 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.sharding_impls import (PmapSharding, NamedSharding,
-                                     PartitionSpec as P)
+                                     PartitionSpec as P, canonicalize_sharding)
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
 from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
                            safe_map, safe_zip, split_list, weakref_lru_cache)
@@ -586,6 +586,8 @@ def _convert_element_type(
       isinstance(operand, Array)):
     sharding = operand.sharding

+  sharding = canonicalize_sharding(sharding, check_mesh_consistency=False)  # type: ignore
+
   if (warn_on_complex_to_real_cast and
       dtypes.issubdtype(old_dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
@@ -1431,6 +1433,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   if not config.sharding_in_types.value and sharding is not None:
     raise NotImplementedError("sharding argument to broadcast_in_dim is only "
                               "allowed when sharding_in_types config is on.")
+  sharding = canonicalize_sharding(sharding)
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
       isinstance(operand, Array) and sharding is None):
     return operand
@@ -1505,7 +1508,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
     return operand
   else:
     dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
-
+    sharding = canonicalize_sharding(sharding)
     return reshape_p.bind(
         operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
         dimensions=None if dims is None or same_dims else dims,
@@ -1947,7 +1950,7 @@ def iota(dtype: DTypeLike, size: int) -> Array:
   return broadcasted_iota(dtype, (size,), 0)

 def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
-                     _sharding=None) -> Array:
+                     sharding=None) -> Array:
   """Convenience wrapper around ``iota``."""
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = canonicalize_shape(shape)
@@ -1955,11 +1958,12 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
   static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
   dimension = core.concrete_or_error(
       int, dimension, "dimension argument of lax.broadcasted_iota")
-  if not config.sharding_in_types.value and _sharding is not None:
+  if not config.sharding_in_types.value and sharding is not None:
     raise NotImplementedError('sharding support for broadcasted_iota is not '
                               'implemented outside of sharding_in_types mode.')
+  sharding = canonicalize_sharding(sharding)
   return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
-                     dimension=dimension, sharding=_sharding)
+                     dimension=dimension, sharding=sharding)

 def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
   """Like numpy.eye, create a 2D array with ones on a diagonal."""
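As a usage sketch of the renamed keyword (it mirrors the test change near the end of this diff; assumes `sharding_in_types` is enabled and that a context mesh with axes `'x'` and `'y'` is in effect during tracing; the function is only defined, not called):

```python
import jax
from jax import lax
from jax.sharding import PartitionSpec as P

@jax.jit
def g(x):
  # The keyword is now `sharding` (previously `_sharding`) and accepts a bare
  # PartitionSpec; the mesh is taken from the surrounding tracing context.
  return lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
```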
@@ -5560,7 +5564,7 @@ def _compute_argminmax(value_comparator, get_identity,
   axis, = axes
   indices = broadcasted_iota(
       index_dtype, np.shape(operand), axis,
-      _sharding=operand.sharding if config.sharding_in_types.value else None)
+      sharding=operand.sharding if config.sharding_in_types.value else None)
   res = reduce([operand, indices],
                [get_identity(operand.dtype), np.array(0, index_dtype)],
                _ArgMinMaxReducer(value_comparator),
@@ -671,7 +671,7 @@ def _one_hot(x: Array, num_classes: int, *,
   else:
     rhs_sharding = None
   rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
-                             _sharding=rhs_sharding)
+                             sharding=rhs_sharding)
   return (lhs == rhs).astype(dtype)

 # TODO(slebedev): Change the type of `x` to `ArrayLike`.
@@ -54,7 +54,7 @@ from jax._src.array import ArrayImpl
 from jax._src.core import ShapedArray
 from jax._src.custom_derivatives import custom_jvp
 from jax._src.lax import lax as lax_internal
-from jax._src.lax.lax import ( PrecisionLike,_array_copy,
+from jax._src.lax.lax import (PrecisionLike,_array_copy,
                               _sort_le_comparator, _sort_lt_comparator)
 from jax._src.lib import xla_client as xc
 from jax._src.numpy import reductions
@@ -69,8 +69,9 @@ from jax._src.util import (
     NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
     ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
     tuple_replace)
-from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
-                          PartitionSpec as P)
+from jax.sharding import Sharding
+from jax._src.sharding_impls import (SingleDeviceSharding, NamedSharding,
+                                     PartitionSpec as P, canonicalize_sharding)
 from jax.tree_util import tree_flatten, tree_leaves, tree_map
 import numpy as np
 import opt_einsum
@@ -9873,6 +9874,7 @@ def _einsum(
   if out_type is not None and not config.sharding_in_types.value:
     raise NotImplementedError("out_type only works when sharding_in_types "
                               "config is True.")
+  out_type = canonicalize_sharding(out_type)
   if out_type is not None and not isinstance(out_type, NamedSharding):
     raise NotImplementedError(
         "`out_type` argument of `einsum` only supports NamedSharding instances."
@@ -67,7 +67,8 @@ from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
-    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
+    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources,
+    parse_flatten_op_sharding, canonicalize_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
 from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
 from jax._src.traceback_util import api_boundary
@@ -2670,13 +2671,20 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: ()

 def sharding_cast(xs, shardings):
   if isinstance(shardings, NamedSharding):
-    return tree_map(lambda x: sharding_cast_p.bind(
-        x, src_sharding=x.sharding, dst_sharding=shardings), xs)
+    return tree_map(
+        lambda x: sharding_cast_p.bind(
+            x, src_sharding=x.sharding, dst_sharding=canonicalize_sharding(
+                shardings, check_mesh_consistency=False)),
+        xs)

   x_flat, treedef = tree_flatten(xs)
   shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings)
-  out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s)
-              for x, s in safe_zip(x_flat, shardings_flat)]
+  out_flat = [
+      sharding_cast_p.bind(
+          x, src_sharding=x.sharding,
+          dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
+      for x, s in safe_zip(x_flat, shardings_flat)
+  ]
   return tree_unflatten(treedef, out_flat)

 sharding_cast_p = core.Primitive('sharding_cast')
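A hedged sketch of the two call forms the updated `sharding_cast` accepts. The import path is an assumption based on where the function is defined in this diff (alongside `sharding_constraint` in `pjit.py`), not a stable public API, and the jitted function is only defined here, not invoked:

```python
import jax
from jax._src.pjit import sharding_cast  # assumed location; internal API
from jax.sharding import NamedSharding, PartitionSpec as P

@jax.jit
def cast_both_ways(x):
  # Existing form: an explicit NamedSharding destination.
  x = sharding_cast(x, NamedSharding(x.sharding.mesh, P('x', None)))
  # New form: a bare PartitionSpec. It is canonicalized to a NamedSharding
  # (mesh read from context, check_mesh_consistency=False) before
  # sharding_cast_p.bind ever sees it.
  x = sharding_cast(x, P('x', None))
  return x
```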
@@ -24,11 +24,13 @@ import math
 from typing import Any, NamedTuple, Union, cast

 from jax._src import core
+from jax._src import config
 from jax._src import mesh as mesh_lib
-from jax._src import sharding
+from jax._src import sharding as jsharding
 from jax._src import sharding_specs
 from jax._src import tree_util
 from jax._src import util
+from jax._src import source_info_util
 from jax._src import xla_bridge
 from jax._src import mesh_utils
 from jax._src.lib import xla_client as xc
@@ -45,7 +47,7 @@ Device = xc.Device
 Index = tuple[slice, ...]
 XLADeviceAssignment = tuple[Device, ...]
 # TODO(yashkatariya): Remove this after 3 months of deprecation.
-XLACompatibleSharding = sharding.Sharding
+XLACompatibleSharding = jsharding.Sharding

 @dataclasses.dataclass(frozen=True)
 class TransferToMemoryKind:
@@ -219,7 +221,7 @@ def named_sharding_to_xla_hlo_sharding(


 @use_cpp_class(xc.NamedSharding)
-class NamedSharding(sharding.Sharding):
+class NamedSharding(jsharding.Sharding):
   r"""A :class:`NamedSharding` expresses sharding using named axes.

   A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
@@ -388,9 +390,6 @@ class NamedSharding(sharding.Sharding):
       spec = PartitionSpec(*spec)
     return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)

-  def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding:
-    return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind)
-
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

@@ -415,7 +414,7 @@ def get_replicated_hlo_sharding():


 @use_cpp_class(xc.SingleDeviceSharding)
-class SingleDeviceSharding(sharding.Sharding):
+class SingleDeviceSharding(jsharding.Sharding):
   """A :class:`Sharding` that places its data on a single device.

   Args:
@@ -503,7 +502,7 @@ def pmap_sharding_devices_indices_map(


 @use_cpp_class(xc.PmapSharding)
-class PmapSharding(sharding.Sharding):
+class PmapSharding(jsharding.Sharding):
   """Describes a sharding used by :func:`jax.pmap`."""
   devices: np.ndarray
   sharding_spec: sharding_specs.ShardingSpec
@@ -713,7 +712,7 @@ def _positional_sharding_to_xla_hlo_sharding(
   return xc.HloSharding.from_proto(pbuf)


-class PositionalSharding(sharding.Sharding):
+class PositionalSharding(jsharding.Sharding):
   _devices: tuple[xc.Device, ...]
   _memory_kind: str | None
   _ids: np.ndarray  # dtype DeviceIdSet
@@ -820,7 +819,7 @@ class PositionalSharding(sharding.Sharding):
   def is_fully_replicated(self) -> bool:
     return self.shape == (1,) * self.ndim

-  # sharding.Sharding interface
+  # jsharding.Sharding interface

   @property
   def _device_assignment(self) -> XLADeviceAssignment:
@@ -868,7 +867,7 @@ class DeviceIdSet:


 @use_cpp_class(xc.GSPMDSharding)
-class GSPMDSharding(sharding.Sharding):
+class GSPMDSharding(jsharding.Sharding):
   _devices: tuple[Device, ...]
   _hlo_sharding: xc.HloSharding
   _memory_kind: str | None
@@ -1122,7 +1121,7 @@ def prepare_axis_resources(axis_resources, arg_name,
   for entry in entries:
     if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
       new_entries.append(entry)
-    elif isinstance(entry, sharding.Sharding):
+    elif isinstance(entry, jsharding.Sharding):
       if isinstance(entry, PmapSharding):
         raise ValueError(f'One of {what} got sharding {entry} which is not '
                          'allowed.')
@@ -1138,7 +1137,7 @@ def prepare_axis_resources(axis_resources, arg_name,
 def _check_unique_resources(axis_resources, arg_name):
   for arg_axis_resources in axis_resources:
     if not arg_axis_resources: continue
-    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)):
+    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
       continue
     constrained_dims = [d for d in arg_axis_resources if d is not None]
     resource_counts = collections.Counter(
@@ -1371,7 +1370,7 @@ class NonUniformShardingError(ValueError):


 def get_process_index_and_count(
-    tensor_sharding: sharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
+    tensor_sharding: jsharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
   """Get current process index and number of unique processes for given dimension.

   This function facilitates mapping of process-level data to individual
@@ -1486,7 +1485,7 @@ def get_process_index_and_count(


 def local_to_global_shape(
-    sharding: sharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
+    sharding: jsharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
   """Computes the global shape given the per process if possible.

   The returned shape will have the size of the global tensor in that dimension
@@ -1545,7 +1544,7 @@ def local_to_global_shape(


 def num_addressable_indices(
-    tensor_sharding: sharding.Sharding, dim: int, global_shape: Shape) -> int:
+    tensor_sharding: jsharding.Sharding, dim: int, global_shape: Shape) -> int:
   """Returns the number of indices for given dimension this host has access to.

   Each host can have multiple number of devices that are spanning
@@ -1579,7 +1578,7 @@ def num_addressable_indices(
   """
   # TODO(sandler, yashkatariya): Consider making this function public.
   addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
-  addressables = cast(Mapping[sharding.Device, Index], addressables)
+  addressables = cast(Mapping[jsharding.Device, Index], addressables)
   num_unique_slices = len({
       _slice_as_tuple(addressable[dim]) for addressable in addressables.values()
   })
@@ -1596,7 +1595,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
   new_op_sharding.tile_assignment_dimensions = tad
   return xc.HloSharding.from_proto(new_op_sharding)

-def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
+def is_single_device_sharding(sharding: jsharding.Sharding) -> bool:
   # Special case PmapSharding here because PmapSharding maps away an axis
   # and needs to be handled separately.test_pjit_single_device_sharding_add
   return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)
@@ -1625,7 +1624,7 @@ def make_key_array_phys_sharding(aval, sharding):


 def physical_sharding(
-    aval, sharding: sharding.Sharding) -> sharding.Sharding:
+    aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
   return make_key_array_phys_sharding(aval, sharding)


@@ -1642,7 +1641,7 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
   return GSPMDSharding(phys_sharding._device_assignment,
                        xc.HloSharding.from_proto(logical_op_sharding))

-def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
+def check_replicated_trailing_dims(sharding: jsharding.Sharding, aval):
   if isinstance(sharding, PmapSharding):
     return
   phys_aval = core.physical_aval(aval)
@@ -1655,7 +1654,7 @@ def check_replicated_trailing_dims(sharding: sharding.Sharding, aval):
       f" sharding: {sharding}, partitions: {partitions}, "
       f"num_trailing_dims: {num_trailing_dims}")

-def logical_sharding(aval, phys_sharding) -> sharding.Sharding:
+def logical_sharding(aval, phys_sharding) -> jsharding.Sharding:
   # The trailing dims should always be replicated.
   check_replicated_trailing_dims(phys_sharding, aval)

@@ -1695,6 +1694,44 @@ def _gspmd_to_named_sharding_via_mesh(
       mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
       out_s.memory_kind)

+def flatten_spec(spec):
+  out = []
+  for s in spec:
+    if s is None:
+      continue
+    if isinstance(s, tuple):
+      out.extend(s)
+    else:
+      out.append(s)
+  return out
+
+def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
+                          check_mesh_consistency: bool = True
+                          ) -> NamedSharding | None:
+  if not config.sharding_in_types.value:
+    return sharding  # type: ignore
+  if sharding is None:
+    return sharding
+
+  if isinstance(sharding, PartitionSpec):
+    sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding)  # type: ignore
+  else:
+    if (check_mesh_consistency and
+        sharding.mesh != mesh_lib.get_abstract_mesh()):
+      raise ValueError(
+          f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
+          f' of sharding {sharding.mesh}. This error occurs at source: '
+          f' {source_info_util.summarize(source_info_util.current())}')
+
+  for s in flatten_spec(sharding.spec):
+    if sharding.mesh._name_to_type[s] in {
+        mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Collective}:
+      raise ValueError(
+          'PartitionSpec cannot contain axis names that are of type Auto or'
+          f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
+          f' {s} or type: {sharding.mesh._name_to_type[s]}')
+  return sharding
+

 def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
               *, devices: Sequence[xc.Device] | None = None,
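To summarize the division of labor the helper above establishes, here is a self-contained toy sketch of the wrapper pattern: the primitive and wrapper names are invented for illustration, and only the real APIs visible in this diff (`core.Primitive`, `mesh_lib.get_abstract_mesh`) are used. Real wrappers such as `broadcast_in_dim` and `reshape` follow the same shape via `canonicalize_sharding`. The real helper additionally rejects specs naming Auto or Collective mesh axes and, unless `check_mesh_consistency=False`, shardings whose mesh differs from the context mesh.

```python
from jax._src import core
from jax._src import mesh as mesh_lib
from jax.sharding import NamedSharding, PartitionSpec

# Toy primitive, for illustration only; it is never bound here.
demo_p = core.Primitive('demo')

def demo(x, out_sharding=None):
  # User-facing layer: accept NamedSharding, PartitionSpec, or None. A bare
  # PartitionSpec is resolved against the ambient abstract mesh, so the
  # primitive below only ever sees a NamedSharding (or None).
  if isinstance(out_sharding, PartitionSpec):
    out_sharding = NamedSharding(mesh_lib.get_abstract_mesh(), out_sharding)
  return demo_p.bind(x, sharding=out_sharding)
```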
@@ -4858,8 +4858,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @jax.jit
     def f(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return jnp.sum(out)

@@ -5155,9 +5154,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def g(x):
       x = x * 2
-      y = jax.lax.broadcasted_iota(
-          x.dtype, (8, 2), 0,
-          _sharding=NamedSharding(mesh.abstract_mesh, P('x', 'y')))
+      y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y'))
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       return x, y

@@ -5186,8 +5183,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @jax.jit
     def g(x, y):
-      out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+      out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
       self.assertEqual(out.sharding.spec, P('x', None))
       return out

@@ -5216,9 +5212,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):

     @jax.jit
     def h(x, y):
-      s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None))
-      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s)
-      self.assertEqual(out.sharding.spec, s.spec)
+      spec = P('x', None, 'y', None)
+      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=spec)
+      self.assertEqual(out.sharding.spec, spec)
       return out

     arr1 = jax.device_put(np_inp.reshape(8, 4, 2),
@@ -5263,8 +5259,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      self.assertEqual(y.sharding.spec, dst_spec)
      return y

-    new_s = (NamedSharding(mesh.abstract_mesh, dst_spec)
-             if use_sharding_arg else None)
+    new_s = dst_spec if use_sharding_arg else None
     out = f(arr, new_s)
     self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec))
     self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)
@@ -5659,15 +5654,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      y = x * 2
      auto_mesh = mesh_lib.get_abstract_mesh().with_axis_types(
          {mesh_lib.AxisTypes.Auto: ('x', 'y')})
-     y = sharding_cast(y, y.sharding.with_mesh(auto_mesh))
+     with mesh_lib.set_abstract_mesh(auto_mesh):
+       y = sharding_cast(y, P(None, None))
      self.assertEqual(y.sharding.spec, P(None, None))
      z = jnp.sin(y)
      self.assertEqual(z.sharding.spec, P(None, None))
      a = z @ z.T
      self.assertEqual(a.sharding.spec, P(None, None))
-     a = sharding_cast(
-         a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
+     a = sharding_cast(a, P('x', None))
      self.assertEqual(a.sharding.spec, P('x', None))
      return a

@@ -5697,8 +5691,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      self.assertEqual(z.sharding.spec, P(None, 'y'))
      a = z @ z.T
      self.assertEqual(a.sharding.spec, P(None, None))
-     a = sharding_cast(
-         a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
+     a = sharding_cast(a, P(None, None))
      self.assertEqual(a.sharding.spec, P(None, None))
      return a

@@ -5719,15 +5712,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      y = x * 2
      mix_mesh = mesh_lib.get_abstract_mesh().with_axis_types(
          {mesh_lib.AxisTypes.Auto: 'x', mesh_lib.AxisTypes.User: 'y'})
-     y = sharding_cast(y, y.sharding.with_mesh(mix_mesh))
+     with mesh_lib.set_abstract_mesh(mix_mesh):
+       y = sharding_cast(y, P(None, 'y'))
      self.assertEqual(y.sharding.spec, P(None, 'y'))
      z = jnp.sin(y)
      self.assertEqual(z.sharding.spec, P(None, 'y'))
      a = z @ z.T
      self.assertEqual(a.sharding.spec, P(None, None))
-     a = sharding_cast(
-         a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
+     a = sharding_cast(a, P('x', None))
      self.assertEqual(a.sharding.spec, P('x', None))
      return a

@@ -5759,8 +5751,10 @@ class ShardingInTypesTest(jtu.JaxTestCase):

   def test_sharding_cast_src_dst_mesh_mismatch(self):
     np_inp = np.arange(16.).reshape(8, 2)
-    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
-    mesh2 = jtu.create_mesh((2, 1), ('a', 'b'))
+    mesh = jtu.create_mesh((2, 1), ('x', 'y'),
+                           axis_types={mesh_lib.AxisTypes.User: ('x', 'y')})
+    mesh2 = jtu.create_mesh((2, 1), ('a', 'b'),
+                            axis_types={mesh_lib.AxisTypes.User: ('a', 'b')})
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
     f = lambda x: sharding_cast(x, NamedSharding(mesh2, P('a', 'b')))
@@ -5812,7 +5806,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     def f(x):
       auto_mesh = get_abstract_mesh().with_axis_types({AxisTypes.Auto: 'x'})
       with set_abstract_mesh(auto_mesh):
-        x = sharding_cast(x, x.sharding.with_mesh(auto_mesh))
+        x = sharding_cast(x, P(None, None))
       return x

     self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
@@ -5850,13 +5844,23 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     @jax.jit
     def f(x, y):
       out = jnp.einsum('xy,yz->xz', x, y,
-                       out_type=NamedSharding(auto_mesh, P('x', None)))
+                       out_type=NamedSharding(auto_mesh, P(None, None)))
       return out

     with self.assertRaisesRegex(
-        ValueError, "context mesh.* should match the aval mesh"):
+        ValueError, "Context mesh.*should match the mesh of sharding"):
       f(arr1, arr2)

+    @jax.jit
+    def g(x, y):
+      with mesh_lib.set_abstract_mesh(auto_mesh):
+        out = jnp.einsum('xy,yz->xz', x, y, out_type=P('x', None))
+      return out
+
+    with self.assertRaisesRegex(
+        ValueError, "PartitionSpec cannot contain axis names.*Auto"):
+      g(arr1, arr2)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):