Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).

PiperOrigin-RevId: 522360232
Peter Hawkins 2023-04-06 09:48:14 -07:00 committed by jax authors
parent b8ade584bf
commit dfe95dcb4e
8 changed files with 451 additions and 373 deletions
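In short, the pmap-era sharding-spec types and helpers move from `jax._src.interpreters.pxla` into the new `jax._src.sharding_specs` module, with thin aliases left behind for compatibility. A minimal sketch of the relocated surface as it appears in the diff below (these are internal JAX APIs, so the import paths are version-specific):

```python
# Sketch only: internal JAX module introduced by this commit; the symbols are
# the ones added in jax/_src/sharding_specs.py below.
from jax._src import sharding_specs

# Types that previously lived on jax._src.interpreters.pxla:
spec = sharding_specs.ShardingSpec(
    sharding=(sharding_specs.Unstacked(4),),        # unstack axis 0 across devices
    mesh_mapping=(sharding_specs.ShardedAxis(0),))  # one sharded mesh axis, no replication

# Helpers that moved with them:
indices = sharding_specs.spec_to_indices((4,), spec)          # one index per device
pmap_spec = sharding_specs.create_pmap_sharding_spec((4, 8))  # takes a shape, not an aval
```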

jax/BUILD

@ -182,6 +182,7 @@ py_library_providing_imports_info(
":pretty_printer",
":profiler",
":sharding",
":sharding_specs",
":source_info_util",
":traceback_util",
":tree_util",
@ -411,6 +412,16 @@ pytype_strict_library(
],
)
pytype_strict_library(
name = "sharding_specs",
srcs = ["_src/sharding_specs.py"],
deps = [
":op_shardings",
":util",
"//jax/_src/lib",
] + py_deps("numpy"),
)
pytype_strict_library(
name = "source_info_util",
srcs = ["_src/source_info_util.py"],

jax/_src/api.py

@ -45,6 +45,7 @@ from jax._src import dispatch
from jax._src import effects
from jax._src import array
from jax._src import dtypes
from jax._src import sharding_specs
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import pjit
@ -1782,13 +1783,13 @@ class _PmapFastpathData(NamedTuple):
out_handler: Any
out_pytree_def: Any
# Data needed to handle the inputs.
input_sharding_specs: Sequence[pxla.ShardingSpec]
input_sharding_specs: Sequence[sharding_specs.ShardingSpec]
input_devices: Sequence[xc.Device]
input_indices: Sequence[pxla.Index]
input_indices: Sequence[sharding_specs.Index]
input_array_shardings: Sequence[Any]
# Data needed to build the Array from C++.
out_sharding_specs: Sequence[pxla.ShardingSpec]
out_indices: Sequence[pxla.Index]
out_sharding_specs: Sequence[sharding_specs.ShardingSpec]
out_indices: Sequence[sharding_specs.Index]
out_avals: Sequence[Any]
out_array_shardings: Sequence[Any]
out_committed: Sequence[Any]
@ -2582,7 +2583,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
raise ValueError("the shards passed to device_put_sharded must have "
f"consistent shape and dtype, but got {a1} and {a2}.")
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
return pxla.batched_device_put(
stacked_aval, PmapSharding(np.array(devices), sharding_spec),
xs, list(devices))
@ -2630,7 +2631,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
core.raise_to_shaped(core.get_aval(x)))
assert (isinstance(aval, ShapedArray) and
len(xla.aval_to_xla_shapes(aval)) == 1)
sharding_spec = pxla._create_pmap_sharding_spec(aval)
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
buf = device_put(x, devices[0])
return pxla.batched_device_put(
aval, PmapSharding(np.array(devices), sharding_spec),
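The substantive change in `api.py` (alongside the new import) is that the relocated helper takes a plain shape tuple rather than a `ShapedArray`. A hedged before/after sketch, with `stacked_aval` standing in for the aval that `device_put_sharded` already builds:

```python
import numpy as np
from jax._src import core, sharding_specs

stacked_aval = core.ShapedArray((8, 4), np.int32)  # stand-in for the aval built above
# Before: pxla._create_pmap_sharding_spec(stacked_aval)  (aval-based, private)
# After: only the shape is needed.
sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
```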

jax/_src/interpreters/pxla.py

@ -13,21 +13,6 @@
# limitations under the License.
"""Implementation of pmap and related functionality."""
# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.
from __future__ import annotations
import enum
@ -40,8 +25,8 @@ import logging
import math
import sys
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)
Sequence, Set, Tuple, Type, Union, Iterable,
TYPE_CHECKING, cast)
import numpy as np
@ -57,6 +42,7 @@ from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
@ -73,12 +59,10 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction)
@ -100,214 +84,23 @@ logger = logging.getLogger(__name__)
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
NoSharding = pmap_lib.NoSharding
Chunked = pmap_lib.Chunked
Unstacked = pmap_lib.Unstacked
NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
Unstacked = sharding_specs.Unstacked
ShardedAxis = pmap_lib.ShardedAxis
Replicated = pmap_lib.Replicated
ShardedAxis = sharding_specs.ShardedAxis
Replicated = sharding_specs.Replicated
_UNSHARDED_INSTANCE = NoSharding()
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = jax._src.mesh.Mesh
MeshAxisName = mesh.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = pmap_lib.ShardingSpec
ShardingSpec = sharding_specs.ShardingSpec
OpShardingType = Any
PartitionSpec = sharding_impls.PartitionSpec
def sharding_spec_mesh_shape(self):
sharded_axis_sizes = []
for sharding in self.sharding:
if isinstance(sharding, NoSharding):
continue
elif isinstance(sharding, Unstacked):
sharded_axis_sizes.append(sharding.size)
elif isinstance(sharding, Chunked):
sharded_axis_sizes.extend(sharding.chunks)
else:
assert_unreachable(sharding)
return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas
for a in self.mesh_mapping)
def _get_logical_mesh_ids(mesh_shape):
return np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType] = {}):
"""Converts a ShardingSpec to an OpSharding proto.
See
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
Unfortunately the semantics are not very well described in the proto spec, but the code here might help:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
"""
mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
mesh = _get_logical_mesh_ids(self.mesh_shape)
sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
replicated_maxes = [] # lists mesh axis identifiers to replicate over
for maxis, assignment in enumerate(self.mesh_mapping):
if isinstance(assignment, Replicated):
replicated_maxes.append((maxis, assignment.replicas))
elif isinstance(assignment, ShardedAxis):
sharded_axes[assignment.axis] = maxis
else:
assert_unreachable(assignment)
proto = xc.OpSharding()
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
proto.type = xc.OpSharding.Type.REPLICATED
return proto
else:
proto.type = xc.OpSharding.Type.OTHER
mesh_permutation = []
new_mesh_shape = []
next_sharded_axis = 0
for axis, sharding in enumerate(self.sharding):
if isinstance(sharding, NoSharding):
new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over
elif isinstance(sharding, Chunked):
for nchunks in sharding.chunks:
maxis = sharded_axes[next_sharded_axis]
assert mesh_shape[maxis] == nchunks
mesh_permutation.append(maxis)
next_sharded_axis += 1
new_mesh_shape.append(int(np.prod(sharding.chunks)))
elif isinstance(sharding, Unstacked):
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
else:
assert_unreachable(sharding)
# Create a partial sharding proto if tensor is replicated or partitioned
# specially over some mesh axes.
if replicated_maxes:
last_tile_dims = []
axes_by_type: Dict[OpShardingType, List[jax._src.mesh.MeshAxisName]] = {}
size_by_type: Dict[OpShardingType, int] = defaultdict(lambda: 1)
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
for axis, size in replicated_maxes:
ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
axes_by_type.setdefault(ty, []).append(axis)
size_by_type[ty] *= size
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
last_tile_dims.append(ty)
new_mesh_shape.append(size_by_type[ty])
mesh_permutation.extend(axes)
proto.last_tile_dims = last_tile_dims
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
proto.tile_assignment_dimensions = list(proto_mesh.shape)
proto.tile_assignment_devices = list(proto_mesh.flat)
return proto
def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
"""Returns NumPy-style indices corresponding to a sharding spec.
Args:
shape: The shape of the logical array being sharded.
Returns:
An ndarray with the same shape as the logical mesh (as derived from
`mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
the data array to be placed on a corresponding device. The indices can be
ints, slice objects with step=1, or tuples of those.
"""
assert len(shape) == len(self.sharding), (shape, self.sharding)
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
# Take the op sharding indices generation route for pjit/xmap cases.
if not has_unstacked:
op_sharding_proto = sharding_spec_sharding_proto(self)
return op_shardings.op_sharding_to_numpy_indices(
op_sharding_proto, shape, math.prod(self.mesh_shape)
).reshape(self.mesh_shape)
axis_indices: List[Sequence[Index]] = []
shard_indices_shape = []
for dim, sharding in enumerate(self.sharding):
axis_size = shape[dim]
if isinstance(sharding, NoSharding):
axis_indices.append([slice(None)])
# NOTE: We don't append unsharded dimensions to shard_indices_shape here,
# because they do not appear in the mesh mapping.
elif isinstance(sharding, Unstacked):
assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
axis_indices.append(range(axis_size))
shard_indices_shape.append(axis_size)
elif isinstance(sharding, Chunked):
total_chunks = int(np.prod(sharding.chunks))
shard_size, ragged = divmod(axis_size, total_chunks)
assert not ragged, (axis_size, total_chunks, dim)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
for i in range(total_chunks)])
shard_indices_shape.extend(sharding.chunks)
else:
assert_unreachable(sharding)
# shard_indices is an ndarray representing the sharded axes of the logical array,
# with each dimension having size equal to the number of shards across the corresponding
# logical array dimension, and each element containing the multi-dimensional index that
# is used to extract the corresponding shard of the logical array.
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
for i, idxs in enumerate(it.product(*axis_indices)):
shard_indices[i] = idxs
shard_indices = shard_indices.reshape(shard_indices_shape)
# Ensure that each sharded axis is used exactly once in the mesh mapping
num_sharded_dim = len(shard_indices_shape)
sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
len(sharded_dim_perm) == num_sharded_dim)
# Replicate/reorder the indices according to the mesh mapping
replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm)
perm = [next(replica_dim) if isinstance(a, Replicated) else
len(replica_sizes) + next(sharded_dim)
for a in self.mesh_mapping]
return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
.transpose(perm))
def sharding_spec_repr(self):
return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
ShardingSpec.mesh_shape = property(sharding_spec_mesh_shape)
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
ShardingSpec.indices = sharding_spec_indices
# mypy raises: error: Cannot assign to a method [assignment]
ShardingSpec.__repr__ = sharding_spec_repr # type: ignore
# Do not pollute the namespace
del sharding_spec_mesh_shape, sharding_spec_indices, sharding_spec_repr
def spec_to_indices(shape: Sequence[int],
spec: ShardingSpec) -> Tuple[Index, ...]:
"""Returns numpy-style indices corresponding to a sharding spec.
Each index describes a shard of the array. The order of the indices is the
same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
row-major).
Args:
shape: The shape of the logical array being sharded.
spec: Describes how the array is sharded and how the shards are assigned to
the logical mesh.
Returns:
A tuple of length equal to the size of the mesh (inferred as the product of
sharded dimension sizes and all replication factors). Each element is an
int, a slice object with step=1, or a tuple thereof, to be treated as an
index into the full logical array.
"""
return tuple(spec.indices(shape).flat) # type: ignore
### util
@ -554,21 +347,6 @@ global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
if sharded_dim is not None:
sharded_aval = aval.update(
shape=aval.shape[:sharded_dim] + aval.shape[sharded_dim+1:])
if sharded_dim_size is None:
sharded_dim_size = aval.shape[sharded_dim]
else:
assert sharded_dim_size is not None
sharded_aval = aval
return _pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
sharded_aval, sharded_dim)
# TODO(yashkatariya, phawkins): Remove this function after March 15, 2023.
def make_sharded_device_array(
aval: ShapedArray,
@ -594,7 +372,7 @@ def make_sharded_device_array(
from jax._src import pjit
if sharding_spec is None:
sharding_spec = _create_pmap_sharding_spec(aval)
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
mesh = jax._src.mesh.thread_resources.env.physical_mesh
sharding: sharding_impls.XLACompatibleSharding
@ -602,7 +380,7 @@ def make_sharded_device_array(
sharding = sharding_impls.PmapSharding(
np.asarray([d.device() for d in device_buffers]), sharding_spec)
else:
op_sharding = sharding_spec_sharding_proto(sharding_spec)
op_sharding = sharding_specs.sharding_spec_sharding_proto(sharding_spec)
pspec = pjit.parse_flatten_op_sharding(
op_sharding, mesh)[0].get_partition_spec()
sharding = sharding_impls.NamedSharding(mesh, pspec)
@ -1323,8 +1101,10 @@ class UnloadedPmapExecutable:
local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
input_sharding_specs = [
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, arg_parts, aval, in_axis)
sharding_specs.pmap_sharding_spec(
replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, arg_parts,
cast(ShapedArray, aval).shape, in_axis)
for aval, arg_parts, in_axis in safe_zip(
shards.sharded_avals, local_arg_parts_, pci.in_axes)]
in_shardings = _get_pmap_sharding(local_device_assignment, input_sharding_specs)
@ -1342,15 +1122,16 @@ class UnloadedPmapExecutable:
if out_axis is not None else aval
for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
out_specs = [
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, out_parts, aval, out_axis)
sharding_specs.pmap_sharding_spec(
replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, out_parts, aval.shape, out_axis)
for out_parts, aval, out_axis in safe_zip(
local_out_parts, local_out_avals, pci.out_axes)]
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
if hasattr(pci.backend, "compile_replicated"):
input_indices = [
spec_to_indices(aval.shape, spec) # pytype: disable=attribute-error
sharding_specs.spec_to_indices(aval.shape, spec)
if spec is not None else None
for aval, spec in safe_zip(pci.avals, input_sharding_specs)
]
@ -1384,8 +1165,9 @@ class UnloadedPmapExecutable:
input_indices = []
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
assert isinstance(spec, sharding_impls.PmapSharding), spec
input_indices.append(spec_to_indices(aval.shape, spec.sharding_spec)
if spec.sharding_spec is not None else None)
input_indices.append(
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
if spec.sharding_spec is not None else None)
handle_outs = local_avals_to_results_handler(self.local_output_avals,
self.output_shardings)
handle_args = InputsHandler(self.compiled.local_devices(),
@ -1686,7 +1468,8 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
else:
replicated_aval = aval
# TODO(skye): figure out how partitioning should work here
sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis)
sharding_spec = sharding_specs.pmap_sharding_spec(
nrep, axis_size, 1, None, aval.shape, in_axis)
buf = jax.device_put(val, devices[0])
sharding = sharding_impls.PmapSharding(
@ -1695,59 +1478,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
devices)
def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval,
map_axis: Optional[int]) -> ShardingSpec:
"""Sharding spec for arguments or results of a pmap.
Args:
nrep: number of local XLA replicas (product of local axis sizes)
axis_size: local axis size for outer pmap
npart: total number of XLA partitions (required by sharded_jit calls)
parts: the partitioning of the value or None
sharded_aval: the aval of the value inside the outer pmap, an instance of
a ShapedArray.
map_axis: the axis along which the value is mapped in the outer pmap
Returns:
A ShardingSpec.
"""
assert isinstance(sharded_aval, ShapedArray), sharded_aval
replication_factor, ragged = divmod(nrep, axis_size)
assert not ragged
# get the sharding spec from inner sharded_jits as if we weren't in a pmap
pspec = partitioned_sharding_spec(npart, parts, sharded_aval)
maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
if map_axis is not None:
sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
def shift_sharded_axis(a: MeshDimAssignment):
if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
return ShardedAxis(a.axis + 1)
return a
# replication_factor represents the product of inner pmaps, so it goes
# after the outer pmapped axis at index 0
return ShardingSpec(
sharding=tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)),
mesh_mapping=it.chain([ShardedAxis(sharded_in_axis)],
maybe_replicate,
map(shift_sharded_axis, pspec.mesh_mapping)))
else:
return ShardingSpec(
sharding=pspec.sharding,
mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
def partitioned_sharding_spec(num_partitions: int,
partitions: Optional[Sequence[int]],
aval) -> ShardingSpec:
if partitions is None:
maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
return ShardingSpec(
sharding=[_UNSHARDED_INSTANCE] * len(aval.shape),
mesh_mapping=maybe_replicate)
else:
assert len(partitions) == len(aval.shape)
return ShardingSpec(
# Chunked expects a list of integers
sharding=map(Chunked, [[x] for x in partitions]),
mesh_mapping=map(ShardedAxis, range(len(partitions))))
class ExecuteReplicated:
"""The logic to shard inputs, execute a replicated model, returning outputs."""
@ -2893,18 +2623,18 @@ class UnloadedMeshExecutable:
use_auto_spmd_partitioning=auto_spmd_lowering,
env_options_overrides=compiler_options,
)
opts = compile_options.executable_build_options
if auto_spmd_lowering:
assert mesh is not None
compile_options.executable_build_options.auto_spmd_partitioning_mesh_shape = \
list(mesh.shape.values())
compile_options.executable_build_options.auto_spmd_partitioning_mesh_ids = \
_get_logical_mesh_ids(list(mesh.shape.values())).reshape(-1)
opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
opts.auto_spmd_partitioning_mesh_ids = (
sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
.reshape(-1))
compile_options.parameter_is_tupled_arguments = tuple_args
if _allow_propagation_to_outputs is None:
_allow_propagation_to_outputs = [False] * len(out_shardings)
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
_allow_propagation_to_outputs
opts.allow_spmd_sharding_propagation_to_output = _allow_propagation_to_outputs
if hasattr(backend, "compile_replicated"):
return _compile_replicated_mesh_executable_from_hlo(
@ -3308,28 +3038,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
_check_aval(v.aval, what_thunk)
def _make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
sharding = [_UNSHARDED_INSTANCE] * num_dimensions
next_sharded_axis = 0
# NOTE: sorted is stable, which is important when multiple resources
# map to the same axis.
for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
chunked = sharding[axis]
if isinstance(chunked, NoSharding):
chunked = Chunked([])
sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
"Value mapped to the same mesh axis twice"
mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
next_sharded_axis += 1
return ShardingSpec(sharding, mesh_mapping)
def new_mesh_sharding_specs(axis_sizes, axis_names):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
return partial(_make_sharding_spec, axis_sizes, mesh_axis_pos)
def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
@ -3350,7 +3058,8 @@ def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
'axis size should be zero but got '
f'{aval_shape[axis] % axis_sizes[name]}')
aval_shape[axis] //= axis_sizes[name]
return _make_sharding_spec(axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
return sharding_specs.make_sharding_spec(
axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
return mk_sharding_spec
@ -3367,3 +3076,13 @@ def device_put(x, devices: Sequence[xc.ArrayImpl],
return [jax.device_put(x, device) for device in devices]
else:
return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]
# TODO(phawkins): fix external users not to use these functions.
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
return sharding_specs.create_pmap_sharding_spec(
aval.shape, sharded_dim, sharded_dim_size)
def _pmap_sharding_spec(nrep, axis_size, npart, parts,
sharded_aval, map_axis: Optional[int]) -> ShardingSpec:
return sharding_specs.pmap_sharding_spec(nrep, axis_size, npart, parts,
sharded_aval.shape, map_axis)
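The two wrappers kept at the bottom of `pxla.py` exist purely so external callers of the old private names keep working; they adapt the aval-based signatures to the new shape-based ones. A rough illustration of the intended equivalence, assuming `ShardingSpec` compares structurally (which `PmapSharding.__eq__` already relies on):

```python
import numpy as np
from jax._src import core, sharding_specs
from jax._src.interpreters import pxla

aval = core.ShapedArray((16,), np.float32)
# The compatibility shim forwards to the relocated helper, so both calls
# should produce the same spec for the same input.
assert (pxla._create_pmap_sharding_spec(aval) ==
        sharding_specs.create_pmap_sharding_spec(aval.shape))
```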

jax/_src/prng.py

@ -31,6 +31,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import sharding_specs
from jax._src import typing
from jax._src.api import jit, vmap
from jax._src.config import config
@ -293,8 +294,8 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
return sharding
elif isinstance(sharding, PmapSharding):
key_shape = aval.dtype.impl.key_shape
trailing_sharding = [pxla.NoSharding()] * len(key_shape)
phys_sharding_spec = pxla.ShardingSpec(
trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape)
phys_sharding_spec = sharding_specs.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=sharding.devices,
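For the `PmapSharding` branch above, the physical spec is simply the logical spec extended with one `NoSharding` entry per trailing key dimension, with the mesh mapping left untouched. A small sketch of that construction in isolation (the key shape `(2,)` is a made-up stand-in for `aval.dtype.impl.key_shape`):

```python
from jax._src import sharding_specs

logical_spec = sharding_specs.create_pmap_sharding_spec((8, 3))
key_shape = (2,)  # hypothetical; the real value comes from aval.dtype.impl.key_shape
trailing = [sharding_specs.NoSharding()] * len(key_shape)
phys_spec = sharding_specs.ShardingSpec(
    sharding=(*logical_spec.sharding, *trailing),  # per-dim sharding gains the key dims
    mesh_mapping=logical_spec.mesh_mapping)        # device mesh layout is unchanged
```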

jax/_src/sharding_impls.py

@ -20,15 +20,14 @@ import operator as op
from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)
from jax._src import core
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding
from jax._src import sharding_specs
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
import numpy as np
@ -296,7 +295,7 @@ class NamedSharding(XLACompatibleSharding):
array_mapping = get_array_mapping(self._parsed_pspec)
# TODO(yashkatariya): Move away from sharding spec in NamedSharding
# since we don't really need sharding spec.
sharding_spec = pxla.new_mesh_sharding_specs(
sharding_spec = sharding_specs.new_mesh_sharding_specs(
self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
# Used in `with_sharding_constraint`.
special_axes = {}
@ -369,11 +368,11 @@ class SingleDeviceSharding(XLACompatibleSharding):
@use_cpp_class(xc.PmapSharding)
class PmapSharding(XLACompatibleSharding):
devices: np.ndarray
sharding_spec: pxla.ShardingSpec
sharding_spec: sharding_specs.ShardingSpec
@use_cpp_method()
def __init__(self, devices: Union[Sequence[Device], np.ndarray],
sharding_spec: pxla.ShardingSpec):
sharding_spec: sharding_specs.ShardingSpec):
self.devices = np.asarray(devices)
# The sharding spec should be pmap's sharding spec.
self.sharding_spec = sharding_spec
@ -421,12 +420,12 @@ class PmapSharding(XLACompatibleSharding):
"""
# The dtype doesn't matter here. Its only used for creating the
# sharding_spec.
aval = core.ShapedArray(shape, np.int32)
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
sharding_spec = sharding_specs.create_pmap_sharding_spec(
tuple(shape), sharded_dim)
num_ways_sharded = None
for s in sharding_spec.sharding:
if isinstance(s, pxla.Unstacked):
if isinstance(s, sharding_specs.Unstacked):
num_ways_sharded = s.size
if num_ways_sharded is None:
raise NotImplementedError(
@ -444,7 +443,7 @@ class PmapSharding(XLACompatibleSharding):
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
self.shard_shape(global_shape) # raises a good error message
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
indices = sharding_specs.spec_to_indices(global_shape, self.sharding_spec)
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
@functools.cached_property
@ -459,7 +458,7 @@ class PmapSharding(XLACompatibleSharding):
sharded_dim = None
sharded_dim_size = None
for i, s in enumerate(self.sharding_spec.sharding):
if isinstance(s, pxla.Unstacked):
if isinstance(s, sharding_specs.Unstacked):
sharded_dim = i
sharded_dim_size = s.size
break
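Since `PmapSharding.devices_indices_map` now routes through `sharding_specs.spec_to_indices`, the per-device indices can be inspected with only the APIs touched in this diff. A sketch, assuming at least one local device:

```python
import jax
import numpy as np
from jax._src import sharding_specs

spec = sharding_specs.create_pmap_sharding_spec((1, 3))  # shard axis 0 one way
pmap_sharding = jax.sharding.PmapSharding(np.array(jax.devices()[:1]), spec)
print(pmap_sharding.devices_indices_map((1, 3)))
# {<device>: (0, slice(None, None, None))}  (device repr varies by backend)
```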

jax/_src/sharding_specs.py (new file, 346 lines)

@ -0,0 +1,346 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.
import collections
import functools
import itertools
import math
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
import numpy as np
from jax._src import op_shardings
from jax._src import util
from jax._src.lib import pmap_lib
from jax._src.lib import xla_client as xc
unsafe_map, map = map, util.safe_map
NoSharding = pmap_lib.NoSharding
Chunked = pmap_lib.Chunked
Unstacked = pmap_lib.Unstacked
_UNSHARDED_INSTANCE = NoSharding()
ShardedAxis = pmap_lib.ShardedAxis
Replicated = pmap_lib.Replicated
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = pmap_lib.ShardingSpec
OpShardingType = Any
def _sharding_spec_mesh_shape(self):
sharded_axis_sizes = []
for sharding in self.sharding:
if isinstance(sharding, NoSharding):
continue
elif isinstance(sharding, Unstacked):
sharded_axis_sizes.append(sharding.size)
elif isinstance(sharding, Chunked):
sharded_axis_sizes.extend(sharding.chunks)
else:
util.assert_unreachable(sharding)
return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis)
else a.replicas
for a in self.mesh_mapping)
def get_logical_mesh_ids(mesh_shape):
return np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
_MeshAxisName = Any
def sharding_spec_sharding_proto(
self, special_axes: Mapping[int, OpShardingType] = {}):
"""Converts a ShardingSpec to an OpSharding proto.
See
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
Unfortunately the semantics are not very well described in the proto spec, but
the code here might help:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
"""
mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
mesh = get_logical_mesh_ids(self.mesh_shape)
sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
replicated_maxes = [] # lists mesh axis identifiers to replicate over
for maxis, assignment in enumerate(self.mesh_mapping):
if isinstance(assignment, Replicated):
replicated_maxes.append((maxis, assignment.replicas))
elif isinstance(assignment, ShardedAxis):
sharded_axes[assignment.axis] = maxis
else:
util.assert_unreachable(assignment)
proto = xc.OpSharding()
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
proto.type = xc.OpSharding.Type.REPLICATED
return proto
else:
proto.type = xc.OpSharding.Type.OTHER
mesh_permutation = []
new_mesh_shape = []
next_sharded_axis = 0
for axis, sharding in enumerate(self.sharding):
if isinstance(sharding, NoSharding):
new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over
elif isinstance(sharding, Chunked):
for nchunks in sharding.chunks:
maxis = sharded_axes[next_sharded_axis]
assert mesh_shape[maxis] == nchunks
mesh_permutation.append(maxis)
next_sharded_axis += 1
new_mesh_shape.append(int(np.prod(sharding.chunks)))
elif isinstance(sharding, Unstacked):
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
else:
util.assert_unreachable(sharding)
# Create a partial sharding proto if tensor is replicated or partitioned
# specially over some mesh axes.
if replicated_maxes:
last_tile_dims = []
axes_by_type: Dict[OpShardingType, List[_MeshAxisName]] = {}
size_by_type: Dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
for axis, size in replicated_maxes:
ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
axes_by_type.setdefault(ty, []).append(axis)
size_by_type[ty] *= size
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
last_tile_dims.append(ty)
new_mesh_shape.append(size_by_type[ty])
mesh_permutation.extend(axes)
proto.last_tile_dims = last_tile_dims
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
proto.tile_assignment_dimensions = list(proto_mesh.shape)
proto.tile_assignment_devices = list(proto_mesh.flat)
return proto
def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
"""Returns NumPy-style indices corresponding to a sharding spec.
Args:
shape: The shape of the logical array being sharded.
Returns:
An ndarray with the same shape as the logical mesh (as derived from
`mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
the data array to be placed on a corresponding device. The indices can be
ints, slice objects with step=1, or tuples of those.
"""
assert len(shape) == len(self.sharding), (shape, self.sharding)
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
# Take the op sharding indices generation route for pjit/xmap cases.
if not has_unstacked:
op_sharding_proto = sharding_spec_sharding_proto(self)
return op_shardings.op_sharding_to_numpy_indices(
op_sharding_proto, shape, math.prod(self.mesh_shape)
).reshape(self.mesh_shape)
axis_indices: List[Sequence[Index]] = []
shard_indices_shape = []
for dim, sharding in enumerate(self.sharding):
axis_size = shape[dim]
if isinstance(sharding, NoSharding):
axis_indices.append([slice(None)])
# NOTE: We don't append unsharded dimensions to shard_indices_shape here,
# because they do not appear in the mesh mapping.
elif isinstance(sharding, Unstacked):
assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
axis_indices.append(range(axis_size))
shard_indices_shape.append(axis_size)
elif isinstance(sharding, Chunked):
total_chunks = int(np.prod(sharding.chunks))
shard_size, ragged = divmod(axis_size, total_chunks)
assert not ragged, (axis_size, total_chunks, dim)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
for i in range(total_chunks)])
shard_indices_shape.extend(sharding.chunks)
else:
util.assert_unreachable(sharding)
# shard_indices is an ndarray representing the sharded axes of the logical array,
# with each dimension having size equal to the number of shards across the corresponding
# logical array dimension, and each element containing the multi-dimensional index that
# is used to extract the corresponding shard of the logical array.
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
for i, idxs in enumerate(itertools.product(*axis_indices)):
shard_indices[i] = idxs
shard_indices = shard_indices.reshape(shard_indices_shape)
# Ensure that each sharded axis is used exactly once in the mesh mapping
num_sharded_dim = len(shard_indices_shape)
sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
len(sharded_dim_perm) == num_sharded_dim)
# Replicate/reorder the indices according to the mesh mapping
replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
replica_dim, sharded_dim = itertools.count(0), iter(sharded_dim_perm)
perm = [next(replica_dim) if isinstance(a, Replicated) else
len(replica_sizes) + next(sharded_dim)
for a in self.mesh_mapping]
return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
.transpose(perm))
def _sharding_spec_repr(self):
return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
ShardingSpec.indices = _sharding_spec_indices
# mypy raises: error: Cannot assign to a method [assignment]
ShardingSpec.__repr__ = _sharding_spec_repr # type: ignore
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
def spec_to_indices(shape: Sequence[int],
spec: ShardingSpec) -> Tuple[Index, ...]:
"""Returns numpy-style indices corresponding to a sharding spec.
Each index describes a shard of the array. The order of the indices is the
same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
row-major).
Args:
shape: The shape of the logical array being sharded.
spec: Describes how the array is sharded and how the shards are assigned to
the logical mesh.
Returns:
A tuple of length equal to the size of the mesh (inferred as the product of
sharded dimension sizes and all replication factors). Each element is an
int, a slice object with step=1, or a tuple thereof, to be treated as an
index into the full logical array.
"""
return tuple(spec.indices(shape).flat) # type: ignore
def partitioned_sharding_spec(num_partitions: int,
partitions: Optional[Sequence[int]],
shape: Sequence[int]) -> ShardingSpec:
if partitions is None:
maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
return ShardingSpec(
sharding=[_UNSHARDED_INSTANCE] * len(shape),
mesh_mapping=maybe_replicate)
else:
assert len(partitions) == len(shape)
return ShardingSpec(
# Chunked expects a list of integers
sharding=map(Chunked, [[x] for x in partitions]),
mesh_mapping=map(ShardedAxis, range(len(partitions))))
def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
sharding = [_UNSHARDED_INSTANCE] * num_dimensions
next_sharded_axis = 0
# NOTE: sorted is stable, which is important when multiple resources
# map to the same axis.
for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
chunked = sharding[axis]
if isinstance(chunked, NoSharding):
chunked = Chunked([])
sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
"Value mapped to the same mesh axis twice"
mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
next_sharded_axis += 1
return ShardingSpec(sharding, mesh_mapping)
def new_mesh_sharding_specs(axis_sizes, axis_names):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
return functools.partial(make_sharding_spec, axis_sizes, mesh_axis_pos)
def pmap_sharding_spec(nrep, axis_size, npart, parts,
sharded_shape: Sequence[int],
map_axis: Optional[int]) -> ShardingSpec:
"""Sharding spec for arguments or results of a pmap.
Args:
nrep: number of local XLA replicas (product of local axis sizes)
axis_size: local axis size for outer pmap
npart: total number of XLA partitions (required by sharded_jit calls)
parts: the partitioning of the value or None
sharded_shape: the shape of the value inside the outer pmap.
map_axis: the axis along which the value is mapped in the outer pmap
Returns:
A ShardingSpec.
"""
replication_factor, ragged = divmod(nrep, axis_size)
assert not ragged
# get the sharding spec from inner sharded_jits as if we weren't in a pmap
pspec = partitioned_sharding_spec(npart, parts, sharded_shape)
maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
if map_axis is not None:
sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
def shift_sharded_axis(a: MeshDimAssignment):
if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
return ShardedAxis(a.axis + 1)
return a
# replication_factor represents the product of inner pmaps, so it goes
# after the outer pmapped axis at index 0
return ShardingSpec(
sharding=util.tuple_insert(
pspec.sharding, map_axis, Unstacked(axis_size)),
mesh_mapping=itertools.chain(
[ShardedAxis(sharded_in_axis)], maybe_replicate,
map(shift_sharded_axis, pspec.mesh_mapping)))
else:
return ShardingSpec(
sharding=pspec.sharding,
mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
sharded_dim_size: Optional[int] = None):
if sharded_dim is not None:
sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]
if sharded_dim_size is None:
sharded_dim_size = shape[sharded_dim]
else:
assert sharded_dim_size is not None
sharded_shape = shape
return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
sharded_shape, sharded_dim)
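The module header's running example can be reproduced directly with the new helpers: partitioning a length-4 array four ways with a replication factor of 2 yields eight per-device indices, laid out row-major with replication as the inner dimension. A worked sketch using only names defined in the file above:

```python
from jax._src import sharding_specs

spec = sharding_specs.ShardingSpec(
    sharding=(sharding_specs.Unstacked(4),),        # split axis 0 into 4 shards
    mesh_mapping=(sharding_specs.ShardedAxis(0),    # shards vary along the outer mesh axis
                  sharding_specs.Replicated(2)))    # each shard replicated twice (inner)
print(sharding_specs.spec_to_indices((4,), spec))
# ((0,), (0,), (1,), (1,), (2,), (2,), (3,), (3,))
# i.e. for data [1, 2, 3, 4] the eight devices hold 1, 1, 2, 2, 3, 3, 4, 4.
```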

jax/interpreters/pxla.py

@ -17,7 +17,6 @@ from jax._src.interpreters.pxla import (
ArrayMapping as ArrayMapping,
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
AvalDimSharding as AvalDimSharding,
Chunked as Chunked,
EmapInfo as EmapInfo,
ExecuteReplicated as ExecuteReplicated,
Index as Index,
@ -27,8 +26,6 @@ from jax._src.interpreters.pxla import (
MeshComputation as MeshComputation,
MeshDimAssignment as MeshDimAssignment,
MeshExecutable as MeshExecutable,
NoSharding as NoSharding,
OpShardingType as OpShardingType,
OrderedDictType as OrderedDictType,
ParallelCallableInfo as ParallelCallableInfo,
PartitionInfo as PartitionInfo,
@ -37,18 +34,14 @@ from jax._src.interpreters.pxla import (
PmapExecutable as PmapExecutable,
PxlaResultHandler as PxlaResultHandler,
ReplicaInfo as ReplicaInfo,
Replicated as Replicated,
ResultsHandler as ResultsHandler,
SPMDBatchTrace as SPMDBatchTrace,
ShardInfo as ShardInfo,
ShardedAxis as ShardedAxis,
ShardingSpec as ShardingSpec,
TileManual as TileManual,
TileVectorize as TileVectorize,
TilingMethod as TilingMethod,
UnloadedMeshExecutable as UnloadedMeshExecutable,
UnloadedPmapExecutable as UnloadedPmapExecutable,
Unstacked as Unstacked,
WeakRefList as WeakRefList,
_UNSPECIFIED as _UNSPECIFIED,
_create_pmap_sharding_spec as _create_pmap_sharding_spec,
@ -76,9 +69,7 @@ from jax._src.interpreters.pxla import (
maybe_extend_axis_env as maybe_extend_axis_env,
mesh_sharding_specs as mesh_sharding_specs,
multi_host_supported_collectives as multi_host_supported_collectives,
new_mesh_sharding_specs as new_mesh_sharding_specs,
parallel_callable as parallel_callable,
partitioned_sharding_spec as partitioned_sharding_spec,
reconcile_num_partitions as reconcile_num_partitions,
replicate as replicate,
resource_typecheck as resource_typecheck,
@ -88,8 +79,6 @@ from jax._src.interpreters.pxla import (
shard_aval as shard_aval,
shard_aval_handlers as shard_aval_handlers,
shard_to_full_p as shard_to_full_p,
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
spec_to_indices as spec_to_indices,
spmd_primitive_batchers as spmd_primitive_batchers,
stage_parallel_callable as stage_parallel_callable,
tile_aval_nd as tile_aval_nd,
@ -114,6 +103,19 @@ from jax._src.op_shardings import (
op_sharding_to_indices as op_sharding_to_indices,
)
from jax._src.sharding_specs import (
Chunked as Chunked,
NoSharding as NoSharding,
OpShardingType as OpShardingType,
Replicated as Replicated,
ShardedAxis as ShardedAxis,
ShardingSpec as ShardingSpec,
Unstacked as Unstacked,
new_mesh_sharding_specs as new_mesh_sharding_specs,
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
spec_to_indices as spec_to_indices,
)
# Deprecations
from jax._src.mesh import Mesh as _deprecated_Mesh
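Because `jax/interpreters/pxla.py` now re-exports these names from `jax._src.sharding_specs`, code spelled against the old public path keeps working and sees the same objects. A quick sketch of the aliasing:

```python
from jax.interpreters import pxla
from jax._src import sharding_specs

# The re-exports are aliases, not copies.
assert pxla.ShardingSpec is sharding_specs.ShardingSpec
assert pxla.spec_to_indices is sharding_specs.spec_to_indices
assert pxla.Chunked is sharding_specs.Chunked
```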

tests/pmap_test.py

@ -40,10 +40,10 @@ from jax._src.lax import parallel
from jax._src import api as src_api
from jax import random
from jax._src import core
from jax._src.core import ShapedArray
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import config as jax_config
from jax._src import sharding_specs
from jax._src import xla_bridge
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
@ -118,13 +118,11 @@ ignore_xmap_warning = partial(
def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
devices=None, sharded_dim_size=None):
dtype = np.int32
aval = ShapedArray(input_shape, dtype)
if input_data is None:
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size)
sharding_spec = sharding_specs.create_pmap_sharding_spec(
input_shape, in_axes, sharded_dim_size)
if devices is None:
devices = jax.devices()
@ -2829,7 +2827,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
mesh_mapping=map(pxla.ShardedAxis, (0, 1)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((slice(0,2), slice(0,4)),
(slice(0,2), slice(4,8)),
(slice(2,4), slice(0,4)),
@ -2839,7 +2837,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
mesh_mapping=map(pxla.ShardedAxis, (1, 0)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((slice(0,2), slice(0,4)),
(slice(2,4), slice(0,4)),
(slice(0,2), slice(4,8)),
@ -2851,7 +2849,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
mesh_mapping=(pxla.Replicated(2),
pxla.ShardedAxis(1),
pxla.ShardedAxis(0)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((slice(0,2), slice(0,4)),
(slice(2,4), slice(0,4)),
(slice(0,2), slice(4,8)),
@ -2861,7 +2859,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
mesh_mapping=(pxla.ShardedAxis(0),))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((slice(0,2), slice(None)),
(slice(2,4), slice(None))))
@ -2869,14 +2867,14 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
mesh_mapping=())
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((slice(None), slice(None)),))
def testUnmaterializedAxis(self):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(4), pxla.NoSharding()),
mesh_mapping=(pxla.ShardedAxis(0),))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((0, slice(None)),
(1, slice(None)),
(2, slice(None)),
@ -2885,7 +2883,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (2, 2)
spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.Unstacked(2)),
mesh_mapping=(pxla.ShardedAxis(0),))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((slice(None), 0),
(slice(None), 1)))
@ -2893,14 +2891,14 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (2, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
tuple([(0, slice(None))] * 3 + [(1, slice(None))] * 3))
def testReplicationPosition2(self):
shape = (2, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
mesh_mapping=(pxla.ShardedAxis(0), pxla.ShardedAxis(1), pxla.Replicated(3)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((0, slice(0, 4)), (0, slice(0, 4)), (0, slice(0, 4)),
(0, slice(4, 8)), (0, slice(4, 8)), (0, slice(4, 8)),
(1, slice(0, 4)), (1, slice(0, 4)), (1, slice(0, 4)),
@ -2910,7 +2908,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (2, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3), pxla.ShardedAxis(1)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((0, slice(0, 4)), (0, slice(4, 8)),
(0, slice(0, 4)), (0, slice(4, 8)),
(0, slice(0, 4)), (0, slice(4, 8)),
@ -2922,7 +2920,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = (2, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
mesh_mapping=(pxla.Replicated(3), pxla.ShardedAxis(0)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
tuple([(0, slice(None)), (1, slice(None))] * 3))
def testMultipleReplications(self):
@ -2933,7 +2931,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
pxla.ShardedAxis(0), pxla.Replicated(2),
pxla.ShardedAxis(1)))
self.assertEqual(
pxla.spec_to_indices(shape, spec),
sharding_specs.spec_to_indices(shape, spec),
((0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)),
(0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)),
(1, slice(None), slice(0, 2)), (1, slice(None), slice(2, 4)),
@ -2943,7 +2941,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
shape = ()
spec = pxla.ShardingSpec(sharding=(),
mesh_mapping=(pxla.Replicated(3),))
self.assertEqual(pxla.spec_to_indices(shape, spec),
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
((), (), ()))
@ -3003,7 +3001,7 @@ class ShardArgsTest(jtu.JaxTestCase):
mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
])
def testShardArgs(self, shape, spec, make_arg):
indices = pxla.spec_to_indices(shape, spec)
indices = sharding_specs.spec_to_indices(shape, spec)
nshards = len(indices)
if jax.device_count() < nshards:
raise SkipTest
@ -3014,7 +3012,8 @@ class ShardArgsTest(jtu.JaxTestCase):
sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
else:
sharding = jax.sharding.GSPMDSharding(
jax.devices()[:nshards], pxla.sharding_spec_sharding_proto(spec))
jax.devices()[:nshards],
sharding_specs.sharding_spec_sharding_proto(spec))
results = pxla.shard_args(
jax.devices()[:nshards], [indices], [sharding], [arg]