Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
PiperOrigin-RevId: 522360232
This commit is contained in:
parent b8ade584bf
commit dfe95dcb4e
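In short: the spec-construction helpers now live in jax._src.sharding_specs and take plain shapes rather than avals, while pxla keeps thin aval-taking wrappers for existing callers. A minimal sketch of the new entry points (internal JAX APIs, subject to change; shapes chosen for illustration):

from jax._src import sharding_specs

# New style: pass the shape directly (the old pxla._create_pmap_sharding_spec
# shim still accepts an aval and forwards here).
spec = sharding_specs.create_pmap_sharding_spec((8, 4), sharded_dim=0)

# spec_to_indices yields one NumPy-style index per device buffer, in row-major
# device order.
indices = sharding_specs.spec_to_indices((8, 4), spec)
assert len(indices) == 8  # one entry per shard along the sharded dimension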
jax/BUILD (11 changed lines)

@@ -182,6 +182,7 @@ py_library_providing_imports_info(
        ":pretty_printer",
        ":profiler",
        ":sharding",
        ":sharding_specs",
        ":source_info_util",
        ":traceback_util",
        ":tree_util",
@@ -411,6 +412,16 @@ pytype_strict_library(
    ],
)

pytype_strict_library(
    name = "sharding_specs",
    srcs = ["_src/sharding_specs.py"],
    deps = [
        ":op_shardings",
        ":util",
        "//jax/_src/lib",
    ] + py_deps("numpy"),
)

pytype_strict_library(
    name = "source_info_util",
    srcs = ["_src/source_info_util.py"],
@@ -45,6 +45,7 @@ from jax._src import dispatch
from jax._src import effects
from jax._src import array
from jax._src import dtypes
from jax._src import sharding_specs
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import pjit
@@ -1782,13 +1783,13 @@ class _PmapFastpathData(NamedTuple):
  out_handler: Any
  out_pytree_def: Any
  # Data needed to handle the inputs.
  input_sharding_specs: Sequence[pxla.ShardingSpec]
  input_sharding_specs: Sequence[sharding_specs.ShardingSpec]
  input_devices: Sequence[xc.Device]
  input_indices: Sequence[pxla.Index]
  input_indices: Sequence[sharding_specs.Index]
  input_array_shardings: Sequence[Any]
  # Data needed to build the Array from C++.
  out_sharding_specs: Sequence[pxla.ShardingSpec]
  out_indices: Sequence[pxla.Index]
  out_sharding_specs: Sequence[sharding_specs.ShardingSpec]
  out_indices: Sequence[sharding_specs.Index]
  out_avals: Sequence[Any]
  out_array_shardings: Sequence[Any]
  out_committed: Sequence[Any]
@@ -2582,7 +2583,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  #
      raise ValueError("the shards passed to device_put_sharded must have "
                       f"consistent shape and dtype, but got {a1} and {a2}.")
    stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
    sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
    sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
    return pxla.batched_device_put(
        stacked_aval, PmapSharding(np.array(devices), sharding_spec),
        xs, list(devices))
@@ -2630,7 +2631,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
        core.raise_to_shaped(core.get_aval(x)))
    assert (isinstance(aval, ShapedArray) and
            len(xla.aval_to_xla_shapes(aval)) == 1)
    sharding_spec = pxla._create_pmap_sharding_spec(aval)
    sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
    buf = device_put(x, devices[0])
    return pxla.batched_device_put(
        aval, PmapSharding(np.array(devices), sharding_spec),
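For context on the device_put_sharded change above, a hedged sketch of what the new shape-based helper builds (shard count invented for illustration): the leading device axis is Unstacked and mapped onto the device mesh.

from jax._src import sharding_specs

# e.g. 4 shards of shape (3,) stacked into a (4, 3) aval
spec = sharding_specs.create_pmap_sharding_spec((4, 3))
# spec.sharding     -> (Unstacked(4), NoSharding())
# spec.mesh_mapping -> (ShardedAxis(0),)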
@@ -13,21 +13,6 @@
# limitations under the License.
"""Implementation of pmap and related functionality."""

# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.

from __future__ import annotations

import enum
@@ -40,8 +25,8 @@ import logging
import math
import sys
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
                    Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
                    TYPE_CHECKING)
                    Sequence, Set, Tuple, Type, Union, Iterable,
                    TYPE_CHECKING, cast)

import numpy as np

@@ -57,6 +42,7 @@ from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
@@ -73,12 +59,10 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
                           wrap_name, assert_unreachable,
                           tuple_insert, tuple_delete, distributed_debug_log,
                           wrap_name, tuple_delete, distributed_debug_log,
                           unzip2, HashableFunction)

@@ -100,214 +84,23 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||
|
||||
NoSharding = pmap_lib.NoSharding
|
||||
Chunked = pmap_lib.Chunked
|
||||
Unstacked = pmap_lib.Unstacked
|
||||
NoSharding = sharding_specs.NoSharding
|
||||
Chunked = sharding_specs.Chunked
|
||||
Unstacked = sharding_specs.Unstacked
|
||||
|
||||
ShardedAxis = pmap_lib.ShardedAxis
|
||||
Replicated = pmap_lib.Replicated
|
||||
ShardedAxis = sharding_specs.ShardedAxis
|
||||
Replicated = sharding_specs.Replicated
|
||||
|
||||
_UNSHARDED_INSTANCE = NoSharding()
|
||||
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
|
||||
Mesh = jax._src.mesh.Mesh
|
||||
MeshAxisName = mesh.MeshAxisName
|
||||
MeshDimAssignment = Union[ShardedAxis, Replicated]
|
||||
ShardingSpec = pmap_lib.ShardingSpec
|
||||
ShardingSpec = sharding_specs.ShardingSpec
|
||||
|
||||
OpShardingType = Any
|
||||
|
||||
PartitionSpec = sharding_impls.PartitionSpec
|
||||
|
||||
|
||||
def sharding_spec_mesh_shape(self):
|
||||
sharded_axis_sizes = []
|
||||
for sharding in self.sharding:
|
||||
if isinstance(sharding, NoSharding):
|
||||
continue
|
||||
elif isinstance(sharding, Unstacked):
|
||||
sharded_axis_sizes.append(sharding.size)
|
||||
elif isinstance(sharding, Chunked):
|
||||
sharded_axis_sizes.extend(sharding.chunks)
|
||||
else:
|
||||
assert_unreachable(sharding)
|
||||
return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas
|
||||
for a in self.mesh_mapping)
|
||||
|
||||
|
||||
def _get_logical_mesh_ids(mesh_shape):
|
||||
return np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
|
||||
|
||||
|
||||
def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType] = {}):
|
||||
"""Converts a ShardingSpec to an OpSharding proto.
|
||||
|
||||
See
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
|
||||
for details on the OpSharding proto.
|
||||
Unfortunately the semantics are not very well described in the proto spec, but the code here might help:
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
|
||||
"""
|
||||
mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
|
||||
mesh = _get_logical_mesh_ids(self.mesh_shape)
|
||||
|
||||
sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
|
||||
replicated_maxes = [] # lists mesh axis identifiers to replicate over
|
||||
for maxis, assignment in enumerate(self.mesh_mapping):
|
||||
if isinstance(assignment, Replicated):
|
||||
replicated_maxes.append((maxis, assignment.replicas))
|
||||
elif isinstance(assignment, ShardedAxis):
|
||||
sharded_axes[assignment.axis] = maxis
|
||||
else:
|
||||
assert_unreachable(assignment)
|
||||
|
||||
proto = xc.OpSharding()
|
||||
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
|
||||
proto.type = xc.OpSharding.Type.REPLICATED
|
||||
return proto
|
||||
else:
|
||||
proto.type = xc.OpSharding.Type.OTHER
|
||||
|
||||
mesh_permutation = []
|
||||
new_mesh_shape = []
|
||||
next_sharded_axis = 0
|
||||
for axis, sharding in enumerate(self.sharding):
|
||||
if isinstance(sharding, NoSharding):
|
||||
new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over
|
||||
elif isinstance(sharding, Chunked):
|
||||
for nchunks in sharding.chunks:
|
||||
maxis = sharded_axes[next_sharded_axis]
|
||||
assert mesh_shape[maxis] == nchunks
|
||||
mesh_permutation.append(maxis)
|
||||
next_sharded_axis += 1
|
||||
new_mesh_shape.append(int(np.prod(sharding.chunks)))
|
||||
elif isinstance(sharding, Unstacked):
|
||||
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
|
||||
else:
|
||||
assert_unreachable(sharding)
|
||||
|
||||
# Create a partial sharding proto if tensor is replicated or partitioned
|
||||
# specially over some mesh axes.
|
||||
if replicated_maxes:
|
||||
last_tile_dims = []
|
||||
axes_by_type: Dict[OpShardingType, List[jax._src.mesh.MeshAxisName]] = {}
|
||||
size_by_type: Dict[OpShardingType, int] = defaultdict(lambda: 1)
|
||||
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
|
||||
for axis, size in replicated_maxes:
|
||||
ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
|
||||
axes_by_type.setdefault(ty, []).append(axis)
|
||||
size_by_type[ty] *= size
|
||||
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
|
||||
last_tile_dims.append(ty)
|
||||
new_mesh_shape.append(size_by_type[ty])
|
||||
mesh_permutation.extend(axes)
|
||||
proto.last_tile_dims = last_tile_dims
|
||||
|
||||
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
|
||||
proto.tile_assignment_dimensions = list(proto_mesh.shape)
|
||||
proto.tile_assignment_devices = list(proto_mesh.flat)
|
||||
return proto
|
||||
|
||||
|
||||
def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
|
||||
"""Returns NumPy-style indices corresponding to a sharding spec.
|
||||
|
||||
Args:
|
||||
shape: The shape of the logical array being sharded.
|
||||
|
||||
Returns:
|
||||
An ndarray with the same shape as the logical mesh (as derived form
|
||||
`mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
|
||||
the data array to be placed on a corresponding device. The indices can be
|
||||
ints, slice objects with step=1, or tuples of those.
|
||||
"""
|
||||
assert len(shape) == len(self.sharding), (shape, self.sharding)
|
||||
|
||||
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
|
||||
# Take the op sharding indices generation route for pjit/xmap cases.
|
||||
if not has_unstacked:
|
||||
op_sharding_proto = sharding_spec_sharding_proto(self)
|
||||
return op_shardings.op_sharding_to_numpy_indices(
|
||||
op_sharding_proto, shape, math.prod(self.mesh_shape)
|
||||
).reshape(self.mesh_shape)
|
||||
|
||||
axis_indices: List[Sequence[Index]] = []
|
||||
shard_indices_shape = []
|
||||
for dim, sharding in enumerate(self.sharding):
|
||||
axis_size = shape[dim]
|
||||
if isinstance(sharding, NoSharding):
|
||||
axis_indices.append([slice(None)])
|
||||
# NOTE: We don't append unsharded dimensions to shard_indices_shape here,
|
||||
# because they do not appear in the mesh mapping.
|
||||
elif isinstance(sharding, Unstacked):
|
||||
assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
|
||||
axis_indices.append(range(axis_size))
|
||||
shard_indices_shape.append(axis_size)
|
||||
elif isinstance(sharding, Chunked):
|
||||
total_chunks = int(np.prod(sharding.chunks))
|
||||
shard_size, ragged = divmod(axis_size, total_chunks)
|
||||
assert not ragged, (axis_size, total_chunks, dim)
|
||||
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
|
||||
for i in range(total_chunks)])
|
||||
shard_indices_shape.extend(sharding.chunks)
|
||||
else:
|
||||
assert_unreachable(sharding)
|
||||
|
||||
# shard_indices is an ndarray representing the sharded axes of the logical array,
|
||||
# with each dimension having size equal to the number of shards across the corresponding
|
||||
# logical array dimension, and each element containing the multi-dimensional index that
|
||||
# is used to extract the corresponding shard of the logical array.
|
||||
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
|
||||
for i, idxs in enumerate(it.product(*axis_indices)):
|
||||
shard_indices[i] = idxs
|
||||
shard_indices = shard_indices.reshape(shard_indices_shape)
|
||||
|
||||
# Ensure that each sharded axis is used exactly once in the mesh mapping
|
||||
num_sharded_dim = len(shard_indices_shape)
|
||||
sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
|
||||
assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
|
||||
len(sharded_dim_perm) == num_sharded_dim)
|
||||
# Replicate/reorder the indices according to the mesh mapping
|
||||
replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
|
||||
replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm)
|
||||
perm = [next(replica_dim) if isinstance(a, Replicated) else
|
||||
len(replica_sizes) + next(sharded_dim)
|
||||
for a in self.mesh_mapping]
|
||||
return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
|
||||
.transpose(perm))
|
||||
|
||||
def sharding_spec_repr(self):
|
||||
return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
|
||||
|
||||
|
||||
ShardingSpec.mesh_shape = property(sharding_spec_mesh_shape)
|
||||
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
|
||||
ShardingSpec.indices = sharding_spec_indices
|
||||
# mypy raises: error: Cannot assign to a method [assignment]
|
||||
ShardingSpec.__repr__ = sharding_spec_repr # type: ignore
|
||||
# Do not pollute the namespace
|
||||
del sharding_spec_mesh_shape, sharding_spec_indices, sharding_spec_repr
|
||||
|
||||
def spec_to_indices(shape: Sequence[int],
|
||||
spec: ShardingSpec) -> Tuple[Index, ...]:
|
||||
"""Returns numpy-style indices corresponding to a sharding spec.
|
||||
|
||||
Each index describes a shard of the array. The order of the indices is the
|
||||
same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
|
||||
row-major).
|
||||
|
||||
Args:
|
||||
shape: The shape of the logical array being sharded.
|
||||
spec: Describes how the array is sharded and how the shards are assigned to
|
||||
the logical mesh.
|
||||
|
||||
Returns:
|
||||
A tuple of length equal to the size of the mesh (inferred as the product of
|
||||
sharded dimension sizes and all replication factors). Each element is an
|
||||
int, a slice object with step=1, or a tuple thereof, to be treated as an
|
||||
index into the full logical array.
|
||||
"""
|
||||
return tuple(spec.indices(shape).flat) # type: ignore
|
||||
|
||||
|
||||
### util
|
||||
|
||||
@@ -554,21 +347,6 @@ global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
|
||||
|
||||
### lazy device-memory persistence and result handling
|
||||
|
||||
|
||||
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
|
||||
if sharded_dim is not None:
|
||||
sharded_aval = aval.update(
|
||||
shape=aval.shape[:sharded_dim] + aval.shape[sharded_dim+1:])
|
||||
if sharded_dim_size is None:
|
||||
sharded_dim_size = aval.shape[sharded_dim]
|
||||
else:
|
||||
assert sharded_dim_size is not None
|
||||
sharded_aval = aval
|
||||
|
||||
return _pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
|
||||
sharded_aval, sharded_dim)
|
||||
|
||||
|
||||
# TODO(yashkatariya, phawkins): Remove this function after March 15, 2023.
|
||||
def make_sharded_device_array(
|
||||
aval: ShapedArray,
|
||||
@@ -594,7 +372,7 @@ def make_sharded_device_array(
|
||||
from jax._src import pjit
|
||||
|
||||
if sharding_spec is None:
|
||||
sharding_spec = _create_pmap_sharding_spec(aval)
|
||||
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
|
||||
|
||||
mesh = jax._src.mesh.thread_resources.env.physical_mesh
|
||||
sharding: sharding_impls.XLACompatibleSharding
|
||||
@@ -602,7 +380,7 @@ def make_sharded_device_array(
|
||||
sharding = sharding_impls.PmapSharding(
|
||||
np.asarray([d.device() for d in device_buffers]), sharding_spec)
|
||||
else:
|
||||
op_sharding = sharding_spec_sharding_proto(sharding_spec)
|
||||
op_sharding = sharding_specs.sharding_spec_sharding_proto(sharding_spec)
|
||||
pspec = pjit.parse_flatten_op_sharding(
|
||||
op_sharding, mesh)[0].get_partition_spec()
|
||||
sharding = sharding_impls.NamedSharding(mesh, pspec)
|
||||
@@ -1323,8 +1101,10 @@ class UnloadedPmapExecutable:
|
||||
|
||||
local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
|
||||
input_sharding_specs = [
|
||||
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
|
||||
parts.local_num_partitions, arg_parts, aval, in_axis)
|
||||
sharding_specs.pmap_sharding_spec(
|
||||
replicas.num_local_replicas, pci.axis_size,
|
||||
parts.local_num_partitions, arg_parts,
|
||||
cast(ShapedArray, aval).shape, in_axis)
|
||||
for aval, arg_parts, in_axis in safe_zip(
|
||||
shards.sharded_avals, local_arg_parts_, pci.in_axes)]
|
||||
in_shardings = _get_pmap_sharding(local_device_assignment, input_sharding_specs)
|
||||
@@ -1342,15 +1122,16 @@ class UnloadedPmapExecutable:
|
||||
if out_axis is not None else aval
|
||||
for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
|
||||
out_specs = [
|
||||
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
|
||||
parts.local_num_partitions, out_parts, aval, out_axis)
|
||||
sharding_specs.pmap_sharding_spec(
|
||||
replicas.num_local_replicas, pci.axis_size,
|
||||
parts.local_num_partitions, out_parts, aval.shape, out_axis)
|
||||
for out_parts, aval, out_axis in safe_zip(
|
||||
local_out_parts, local_out_avals, pci.out_axes)]
|
||||
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
|
||||
|
||||
if hasattr(pci.backend, "compile_replicated"):
|
||||
input_indices = [
|
||||
spec_to_indices(aval.shape, spec) # pytype: disable=attribute-error
|
||||
sharding_specs.spec_to_indices(aval.shape, spec)
|
||||
if spec is not None else None
|
||||
for aval, spec in safe_zip(pci.avals, input_sharding_specs)
|
||||
]
|
||||
@@ -1384,8 +1165,9 @@ class UnloadedPmapExecutable:
|
||||
input_indices = []
|
||||
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
|
||||
assert isinstance(spec, sharding_impls.PmapSharding), spec
|
||||
input_indices.append(spec_to_indices(aval.shape, spec.sharding_spec)
|
||||
if spec.sharding_spec is not None else None)
|
||||
input_indices.append(
|
||||
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
|
||||
if spec.sharding_spec is not None else None)
|
||||
handle_outs = local_avals_to_results_handler(self.local_output_avals,
|
||||
self.output_shardings)
|
||||
handle_args = InputsHandler(self.compiled.local_devices(),
|
||||
@@ -1686,7 +1468,8 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
|
||||
else:
|
||||
replicated_aval = aval
|
||||
# TODO(skye): figure out how partitioning should work here
|
||||
sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis)
|
||||
sharding_spec = sharding_specs.pmap_sharding_spec(
|
||||
nrep, axis_size, 1, None, aval.shape, in_axis)
|
||||
|
||||
buf = jax.device_put(val, devices[0])
|
||||
sharding = sharding_impls.PmapSharding(
|
||||
@@ -1695,59 +1478,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
|
||||
devices)
|
||||
|
||||
|
||||
def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval,
|
||||
map_axis: Optional[int]) -> ShardingSpec:
|
||||
"""Sharding spec for arguments or results of a pmap.
|
||||
Args:
|
||||
nrep: number of local XLA replicas (product of local axis sizes)
|
||||
axis_size: local axis size for outer pmap
|
||||
npart: total number of XLA partitions (required by sharded_jit calls)
|
||||
parts: the partitioning of the value or None
|
||||
sharded_aval: the aval of the value inside the outer pmap, an instance of
|
||||
a ShapedArray.
|
||||
map_axis: the axis along which the value is mapped in the outer pmap
|
||||
Returns:
|
||||
A ShardingSpec.
|
||||
"""
|
||||
assert isinstance(sharded_aval, ShapedArray), sharded_aval
|
||||
replication_factor, ragged = divmod(nrep, axis_size)
|
||||
assert not ragged
|
||||
# get the sharding spec from inner sharded_jits as if we weren't in a pmap
|
||||
pspec = partitioned_sharding_spec(npart, parts, sharded_aval)
|
||||
maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
|
||||
if map_axis is not None:
|
||||
sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
|
||||
def shift_sharded_axis(a: MeshDimAssignment):
|
||||
if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
|
||||
return ShardedAxis(a.axis + 1)
|
||||
return a
|
||||
# replication_factor represents the product of inner pmaps, so it goes
|
||||
# after the outer pmapped axis at index 0
|
||||
return ShardingSpec(
|
||||
sharding=tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)),
|
||||
mesh_mapping=it.chain([ShardedAxis(sharded_in_axis)],
|
||||
maybe_replicate,
|
||||
map(shift_sharded_axis, pspec.mesh_mapping)))
|
||||
else:
|
||||
return ShardingSpec(
|
||||
sharding=pspec.sharding,
|
||||
mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
|
||||
|
||||
def partitioned_sharding_spec(num_partitions: int,
|
||||
partitions: Optional[Sequence[int]],
|
||||
aval) -> ShardingSpec:
|
||||
if partitions is None:
|
||||
maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
|
||||
return ShardingSpec(
|
||||
sharding=[_UNSHARDED_INSTANCE] * len(aval.shape),
|
||||
mesh_mapping=maybe_replicate)
|
||||
else:
|
||||
assert len(partitions) == len(aval.shape)
|
||||
return ShardingSpec(
|
||||
# Chunked expects a list of integers
|
||||
sharding=map(Chunked, [[x] for x in partitions]),
|
||||
mesh_mapping=map(ShardedAxis, range(len(partitions))))
|
||||
|
||||
|
||||
class ExecuteReplicated:
|
||||
"""The logic to shard inputs, execute a replicated model, returning outputs."""
|
||||
@@ -2893,18 +2623,18 @@ class UnloadedMeshExecutable:
|
||||
use_auto_spmd_partitioning=auto_spmd_lowering,
|
||||
env_options_overrides=compiler_options,
|
||||
)
|
||||
opts = compile_options.executable_build_options
|
||||
if auto_spmd_lowering:
|
||||
assert mesh is not None
|
||||
compile_options.executable_build_options.auto_spmd_partitioning_mesh_shape = \
|
||||
list(mesh.shape.values())
|
||||
compile_options.executable_build_options.auto_spmd_partitioning_mesh_ids = \
|
||||
_get_logical_mesh_ids(list(mesh.shape.values())).reshape(-1)
|
||||
opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
|
||||
opts.auto_spmd_partitioning_mesh_ids = (
|
||||
sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
|
||||
.reshape(-1))
|
||||
compile_options.parameter_is_tupled_arguments = tuple_args
|
||||
|
||||
if _allow_propagation_to_outputs is None:
|
||||
_allow_propagation_to_outputs = [False] * len(out_shardings)
|
||||
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
|
||||
_allow_propagation_to_outputs
|
||||
opts.allow_spmd_sharding_propagation_to_output = _allow_propagation_to_outputs
|
||||
|
||||
if hasattr(backend, "compile_replicated"):
|
||||
return _compile_replicated_mesh_executable_from_hlo(
|
||||
@@ -3308,28 +3038,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
|
||||
_check_aval(v.aval, what_thunk)
|
||||
|
||||
|
||||
def _make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
|
||||
mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
|
||||
sharding = [_UNSHARDED_INSTANCE] * num_dimensions
|
||||
next_sharded_axis = 0
|
||||
# NOTE: sorted is stable, which is important when multiple resources
|
||||
# map to the same axis.
|
||||
for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
|
||||
chunked = sharding[axis]
|
||||
if isinstance(chunked, NoSharding):
|
||||
chunked = Chunked([])
|
||||
sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
|
||||
assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
|
||||
"Value mapped to the same mesh axis twice"
|
||||
mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
|
||||
next_sharded_axis += 1
|
||||
return ShardingSpec(sharding, mesh_mapping)
|
||||
|
||||
|
||||
def new_mesh_sharding_specs(axis_sizes, axis_names):
|
||||
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
|
||||
return partial(_make_sharding_spec, axis_sizes, mesh_axis_pos)
|
||||
|
||||
|
||||
def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
|
||||
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
|
||||
@@ -3350,7 +3058,8 @@ def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
|
||||
'axis size should be zero but got '
|
||||
f'{aval_shape[axis] % axis_sizes[name]}')
|
||||
aval_shape[axis] //= axis_sizes[name]
|
||||
return _make_sharding_spec(axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
|
||||
return sharding_specs.make_sharding_spec(
|
||||
axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
|
||||
return mk_sharding_spec
|
||||
|
||||
|
||||
@@ -3367,3 +3076,13 @@ def device_put(x, devices: Sequence[xc.ArrayImpl],
    return [jax.device_put(x, device) for device in devices]
  else:
    return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]

# TODO(phawkins): fix external users not to use these functions.
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
  return sharding_specs.create_pmap_sharding_spec(
      aval.shape, sharded_dim, sharded_dim_size)

def _pmap_sharding_spec(nrep, axis_size, npart, parts,
                        sharded_aval, map_axis: Optional[int]) -> ShardingSpec:
  return sharding_specs.pmap_sharding_spec(nrep, axis_size, npart, parts,
                                           sharded_aval.shape, map_axis)
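The two wrappers above keep the old aval-taking signatures working for external callers that still reach into pxla; new internal code passes shapes. A hedged sketch of the equivalence (array shape chosen arbitrarily):

import numpy as np

from jax._src import core
from jax._src import sharding_specs
from jax._src.interpreters import pxla

aval = core.ShapedArray((8, 4), np.float32)
via_shim = pxla._create_pmap_sharding_spec(aval)               # compatibility path
direct = sharding_specs.create_pmap_sharding_spec(aval.shape)  # new path
assert via_shim == direct  # both build the same pmap-style spec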
@@ -31,6 +31,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import sharding_specs
from jax._src import typing
from jax._src.api import jit, vmap
from jax._src.config import config
@@ -293,8 +294,8 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
    return sharding
  elif isinstance(sharding, PmapSharding):
    key_shape = aval.dtype.impl.key_shape
    trailing_sharding = [pxla.NoSharding()] * len(key_shape)
    phys_sharding_spec = pxla.ShardingSpec(
    trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape)
    phys_sharding_spec = sharding_specs.ShardingSpec(
        sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
        mesh_mapping=sharding.sharding_spec.mesh_mapping)
    return PmapSharding(devices=sharding.devices,
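For the key-array branch above, the physical spec is just the logical pmap spec with unsharded trailing dimensions appended for the key buffer. A hedged sketch (the (2,) key shape is an assumption, matching threefry-style keys):

from jax._src import sharding_specs

logical = sharding_specs.create_pmap_sharding_spec((8,))  # Unstacked over 8 devices
key_shape = (2,)  # assumed trailing key-buffer shape

phys = sharding_specs.ShardingSpec(
    sharding=(*logical.sharding,
              *[sharding_specs.NoSharding()] * len(key_shape)),
    mesh_mapping=logical.mesh_mapping)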
@@ -20,15 +20,14 @@ import operator as op
|
||||
from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set,
|
||||
FrozenSet, Union, cast)
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -296,7 +295,7 @@ class NamedSharding(XLACompatibleSharding):
|
||||
array_mapping = get_array_mapping(self._parsed_pspec)
|
||||
# TODO(yashkatariya): Move away from sharding spec in NamedSharding
|
||||
# since we don't really need sharding spec.
|
||||
sharding_spec = pxla.new_mesh_sharding_specs(
|
||||
sharding_spec = sharding_specs.new_mesh_sharding_specs(
|
||||
self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
|
||||
# Used in `with_sharding_constraint`.
|
||||
special_axes = {}
|
||||
@@ -369,11 +368,11 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
@use_cpp_class(xc.PmapSharding)
|
||||
class PmapSharding(XLACompatibleSharding):
|
||||
devices: np.ndarray
|
||||
sharding_spec: pxla.ShardingSpec
|
||||
sharding_spec: sharding_specs.ShardingSpec
|
||||
|
||||
@use_cpp_method()
|
||||
def __init__(self, devices: Union[Sequence[Device], np.ndarray],
|
||||
sharding_spec: pxla.ShardingSpec):
|
||||
sharding_spec: sharding_specs.ShardingSpec):
|
||||
self.devices = np.asarray(devices)
|
||||
# The sharding spec should be pmap's sharding spec.
|
||||
self.sharding_spec = sharding_spec
|
||||
@@ -421,12 +420,12 @@ class PmapSharding(XLACompatibleSharding):
|
||||
"""
|
||||
# The dtype doesn't matter here. Its only used for creating the
|
||||
# sharding_spec.
|
||||
aval = core.ShapedArray(shape, np.int32)
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
|
||||
sharding_spec = sharding_specs.create_pmap_sharding_spec(
|
||||
tuple(shape), sharded_dim)
|
||||
|
||||
num_ways_sharded = None
|
||||
for s in sharding_spec.sharding:
|
||||
if isinstance(s, pxla.Unstacked):
|
||||
if isinstance(s, sharding_specs.Unstacked):
|
||||
num_ways_sharded = s.size
|
||||
if num_ways_sharded is None:
|
||||
raise NotImplementedError(
|
||||
@@ -444,7 +443,7 @@ class PmapSharding(XLACompatibleSharding):
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
|
||||
self.shard_shape(global_shape) # raises a good error message
|
||||
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
|
||||
indices = sharding_specs.spec_to_indices(global_shape, self.sharding_spec)
|
||||
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
|
||||
|
||||
@functools.cached_property
|
||||
@@ -459,7 +458,7 @@ class PmapSharding(XLACompatibleSharding):
|
||||
sharded_dim = None
|
||||
sharded_dim_size = None
|
||||
for i, s in enumerate(self.sharding_spec.sharding):
|
||||
if isinstance(s, pxla.Unstacked):
|
||||
if isinstance(s, sharding_specs.Unstacked):
|
||||
sharded_dim = i
|
||||
sharded_dim_size = s.size
|
||||
break
|
||||
|
jax/_src/sharding_specs.py (new file, 346 lines)

@@ -0,0 +1,346 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.
import collections
import functools
import itertools
import math
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast

import numpy as np

from jax._src import op_shardings
from jax._src import util
from jax._src.lib import pmap_lib
from jax._src.lib import xla_client as xc

unsafe_map, map = map, util.safe_map

NoSharding = pmap_lib.NoSharding
Chunked = pmap_lib.Chunked
Unstacked = pmap_lib.Unstacked

_UNSHARDED_INSTANCE = NoSharding()

ShardedAxis = pmap_lib.ShardedAxis
Replicated = pmap_lib.Replicated
MeshDimAssignment = Union[ShardedAxis, Replicated]

ShardingSpec = pmap_lib.ShardingSpec

OpShardingType = Any

def _sharding_spec_mesh_shape(self):
|
||||
sharded_axis_sizes = []
|
||||
for sharding in self.sharding:
|
||||
if isinstance(sharding, NoSharding):
|
||||
continue
|
||||
elif isinstance(sharding, Unstacked):
|
||||
sharded_axis_sizes.append(sharding.size)
|
||||
elif isinstance(sharding, Chunked):
|
||||
sharded_axis_sizes.extend(sharding.chunks)
|
||||
else:
|
||||
util.assert_unreachable(sharding)
|
||||
return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis)
|
||||
else a.replicas
|
||||
for a in self.mesh_mapping)
|
||||
|
||||
|
||||
def get_logical_mesh_ids(mesh_shape):
|
||||
return np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
|
||||
|
||||
|
||||
_MeshAxisName = Any
|
||||
|
||||
def sharding_spec_sharding_proto(
|
||||
self, special_axes: Mapping[int, OpShardingType] = {}):
|
||||
"""Converts a ShardingSpec to an OpSharding proto.
|
||||
|
||||
See
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
|
||||
for details on the OpSharding proto.
|
||||
Unfortunately the semantics are not very well described in the proto spec, but
|
||||
the code here might help:
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
|
||||
"""
|
||||
mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
|
||||
mesh = get_logical_mesh_ids(self.mesh_shape)
|
||||
|
||||
sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
|
||||
replicated_maxes = [] # lists mesh axis identifiers to replicate over
|
||||
for maxis, assignment in enumerate(self.mesh_mapping):
|
||||
if isinstance(assignment, Replicated):
|
||||
replicated_maxes.append((maxis, assignment.replicas))
|
||||
elif isinstance(assignment, ShardedAxis):
|
||||
sharded_axes[assignment.axis] = maxis
|
||||
else:
|
||||
util.assert_unreachable(assignment)
|
||||
|
||||
proto = xc.OpSharding()
|
||||
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
|
||||
proto.type = xc.OpSharding.Type.REPLICATED
|
||||
return proto
|
||||
else:
|
||||
proto.type = xc.OpSharding.Type.OTHER
|
||||
|
||||
mesh_permutation = []
|
||||
new_mesh_shape = []
|
||||
next_sharded_axis = 0
|
||||
for axis, sharding in enumerate(self.sharding):
|
||||
if isinstance(sharding, NoSharding):
|
||||
new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over
|
||||
elif isinstance(sharding, Chunked):
|
||||
for nchunks in sharding.chunks:
|
||||
maxis = sharded_axes[next_sharded_axis]
|
||||
assert mesh_shape[maxis] == nchunks
|
||||
mesh_permutation.append(maxis)
|
||||
next_sharded_axis += 1
|
||||
new_mesh_shape.append(int(np.prod(sharding.chunks)))
|
||||
elif isinstance(sharding, Unstacked):
|
||||
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
|
||||
else:
|
||||
util.assert_unreachable(sharding)
|
||||
|
||||
# Create a partial sharding proto if tensor is replicated or partitioned
|
||||
# specially over some mesh axes.
|
||||
if replicated_maxes:
|
||||
last_tile_dims = []
|
||||
axes_by_type: Dict[OpShardingType, List[_MeshAxisName]] = {}
|
||||
size_by_type: Dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
|
||||
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
|
||||
for axis, size in replicated_maxes:
|
||||
ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
|
||||
axes_by_type.setdefault(ty, []).append(axis)
|
||||
size_by_type[ty] *= size
|
||||
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
|
||||
last_tile_dims.append(ty)
|
||||
new_mesh_shape.append(size_by_type[ty])
|
||||
mesh_permutation.extend(axes)
|
||||
proto.last_tile_dims = last_tile_dims
|
||||
|
||||
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
|
||||
proto.tile_assignment_dimensions = list(proto_mesh.shape)
|
||||
proto.tile_assignment_devices = list(proto_mesh.flat)
|
||||
return proto
|
||||
|
||||
|
||||
def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
|
||||
"""Returns NumPy-style indices corresponding to a sharding spec.
|
||||
|
||||
Args:
|
||||
shape: The shape of the logical array being sharded.
|
||||
|
||||
Returns:
|
||||
An ndarray with the same shape as the logical mesh (as derived form
|
||||
`mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
|
||||
the data array to be placed on a corresponding device. The indices can be
|
||||
ints, slice objects with step=1, or tuples of those.
|
||||
"""
|
||||
assert len(shape) == len(self.sharding), (shape, self.sharding)
|
||||
|
||||
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
|
||||
# Take the op sharding indices generation route for pjit/xmap cases.
|
||||
if not has_unstacked:
|
||||
op_sharding_proto = sharding_spec_sharding_proto(self)
|
||||
return op_shardings.op_sharding_to_numpy_indices(
|
||||
op_sharding_proto, shape, math.prod(self.mesh_shape)
|
||||
).reshape(self.mesh_shape)
|
||||
|
||||
axis_indices: List[Sequence[Index]] = []
|
||||
shard_indices_shape = []
|
||||
for dim, sharding in enumerate(self.sharding):
|
||||
axis_size = shape[dim]
|
||||
if isinstance(sharding, NoSharding):
|
||||
axis_indices.append([slice(None)])
|
||||
# NOTE: We don't append unsharded dimensions to shard_indices_shape here,
|
||||
# because they do not appear in the mesh mapping.
|
||||
elif isinstance(sharding, Unstacked):
|
||||
assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
|
||||
axis_indices.append(range(axis_size))
|
||||
shard_indices_shape.append(axis_size)
|
||||
elif isinstance(sharding, Chunked):
|
||||
total_chunks = int(np.prod(sharding.chunks))
|
||||
shard_size, ragged = divmod(axis_size, total_chunks)
|
||||
assert not ragged, (axis_size, total_chunks, dim)
|
||||
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
|
||||
for i in range(total_chunks)])
|
||||
shard_indices_shape.extend(sharding.chunks)
|
||||
else:
|
||||
util.assert_unreachable(sharding)
|
||||
|
||||
# shard_indices is an ndarray representing the sharded axes of the logical array,
|
||||
# with each dimension having size equal to the number of shards across the corresponding
|
||||
# logical array dimension, and each element containing the multi-dimensional index that
|
||||
# is used to extract the corresponding shard of the logical array.
|
||||
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
|
||||
for i, idxs in enumerate(itertools.product(*axis_indices)):
|
||||
shard_indices[i] = idxs
|
||||
shard_indices = shard_indices.reshape(shard_indices_shape)
|
||||
|
||||
# Ensure that each sharded axis is used exactly once in the mesh mapping
|
||||
num_sharded_dim = len(shard_indices_shape)
|
||||
sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
|
||||
assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
|
||||
len(sharded_dim_perm) == num_sharded_dim)
|
||||
# Replicate/reorder the indices according to the mesh mapping
|
||||
replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
|
||||
replica_dim, sharded_dim = itertools.count(0), iter(sharded_dim_perm)
|
||||
perm = [next(replica_dim) if isinstance(a, Replicated) else
|
||||
len(replica_sizes) + next(sharded_dim)
|
||||
for a in self.mesh_mapping]
|
||||
return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
|
||||
.transpose(perm))
|
||||
|
||||
def _sharding_spec_repr(self):
|
||||
return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
|
||||
|
||||
|
||||
ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
|
||||
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
|
||||
ShardingSpec.indices = _sharding_spec_indices
|
||||
# mypy raises: error: Cannot assign to a method [assignment]
|
||||
ShardingSpec.__repr__ = _sharding_spec_repr # type: ignore
|
||||
|
||||
|
||||
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||
|
||||
def spec_to_indices(shape: Sequence[int],
|
||||
spec: ShardingSpec) -> Tuple[Index, ...]:
|
||||
"""Returns numpy-style indices corresponding to a sharding spec.
|
||||
|
||||
Each index describes a shard of the array. The order of the indices is the
|
||||
same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
|
||||
row-major).
|
||||
|
||||
Args:
|
||||
shape: The shape of the logical array being sharded.
|
||||
spec: Describes how the array is sharded and how the shards are assigned to
|
||||
the logical mesh.
|
||||
|
||||
Returns:
|
||||
A tuple of length equal to the size of the mesh (inferred as the product of
|
||||
sharded dimension sizes and all replication factors). Each element is an
|
||||
int, a slice object with step=1, or a tuple thereof, to be treated as an
|
||||
index into the full logical array.
|
||||
"""
|
||||
return tuple(spec.indices(shape).flat) # type: ignore
|
||||
|
||||
|
||||
def partitioned_sharding_spec(num_partitions: int,
|
||||
partitions: Optional[Sequence[int]],
|
||||
shape: Sequence[int]) -> ShardingSpec:
|
||||
if partitions is None:
|
||||
maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
|
||||
return ShardingSpec(
|
||||
sharding=[_UNSHARDED_INSTANCE] * len(shape),
|
||||
mesh_mapping=maybe_replicate)
|
||||
else:
|
||||
assert len(partitions) == len(shape)
|
||||
return ShardingSpec(
|
||||
# Chunked expects a list of integers
|
||||
sharding=map(Chunked, [[x] for x in partitions]),
|
||||
mesh_mapping=map(ShardedAxis, range(len(partitions))))
|
||||
|
||||
|
||||
def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
|
||||
mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
|
||||
sharding = [_UNSHARDED_INSTANCE] * num_dimensions
|
||||
next_sharded_axis = 0
|
||||
# NOTE: sorted is stable, which is important when multiple resources
|
||||
# map to the same axis.
|
||||
for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
|
||||
chunked = sharding[axis]
|
||||
if isinstance(chunked, NoSharding):
|
||||
chunked = Chunked([])
|
||||
sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
|
||||
assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
|
||||
"Value mapped to the same mesh axis twice"
|
||||
mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
|
||||
next_sharded_axis += 1
|
||||
return ShardingSpec(sharding, mesh_mapping)
|
||||
|
||||
|
||||
def new_mesh_sharding_specs(axis_sizes, axis_names):
|
||||
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
|
||||
return functools.partial(make_sharding_spec, axis_sizes, mesh_axis_pos)
|
||||
|
||||
def pmap_sharding_spec(nrep, axis_size, npart, parts,
|
||||
sharded_shape: Sequence[int],
|
||||
map_axis: Optional[int]) -> ShardingSpec:
|
||||
"""Sharding spec for arguments or results of a pmap.
|
||||
Args:
|
||||
nrep: number of local XLA replicas (product of local axis sizes)
|
||||
axis_size: local axis size for outer pmap
|
||||
npart: total number of XLA partitions (required by sharded_jit calls)
|
||||
parts: the partitioning of the value or None
|
||||
sharded_aval: the aval of the value inside the outer pmap, an instance of
|
||||
a ShapedArray.
|
||||
map_axis: the axis along which the value is mapped in the outer pmap
|
||||
Returns:
|
||||
A ShardingSpec.
|
||||
"""
|
||||
replication_factor, ragged = divmod(nrep, axis_size)
|
||||
assert not ragged
|
||||
# get the sharding spec from inner sharded_jits as if we weren't in a pmap
|
||||
pspec = partitioned_sharding_spec(npart, parts, sharded_shape)
|
||||
maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
|
||||
if map_axis is not None:
|
||||
sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
|
||||
def shift_sharded_axis(a: MeshDimAssignment):
|
||||
if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
|
||||
return ShardedAxis(a.axis + 1)
|
||||
return a
|
||||
# replication_factor represents the product of inner pmaps, so it goes
|
||||
# after the outer pmapped axis at index 0
|
||||
return ShardingSpec(
|
||||
sharding=util.tuple_insert(
|
||||
pspec.sharding, map_axis, Unstacked(axis_size)),
|
||||
mesh_mapping=itertools.chain(
|
||||
[ShardedAxis(sharded_in_axis)], maybe_replicate,
|
||||
map(shift_sharded_axis, pspec.mesh_mapping)))
|
||||
else:
|
||||
return ShardingSpec(
|
||||
sharding=pspec.sharding,
|
||||
mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
|
||||
|
||||
|
||||
def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
                              sharded_dim_size: Optional[int] = None):
  if sharded_dim is not None:
    sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]
    if sharded_dim_size is None:
      sharded_dim_size = shape[sharded_dim]
  else:
    assert sharded_dim_size is not None
    sharded_shape = shape

  return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
                            sharded_shape, sharded_dim)
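Tying this back to the module comment's example (a length-4 array partitioned 4 ways with replication factor 2, 8 devices total), a hedged sketch built from the primitives defined in this file, showing the shard-major, replica-minor ordering of spec_to_indices:

from jax._src import sharding_specs

spec = sharding_specs.ShardingSpec(
    sharding=(sharding_specs.Chunked([4]),),
    mesh_mapping=(sharding_specs.ShardedAxis(0), sharding_specs.Replicated(2)))

indices = sharding_specs.spec_to_indices((4,), spec)
assert len(indices) == 8  # devices 0-1 hold shard 0, devices 2-3 hold shard 1, ...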
@@ -17,7 +17,6 @@ from jax._src.interpreters.pxla import (
|
||||
ArrayMapping as ArrayMapping,
|
||||
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
|
||||
AvalDimSharding as AvalDimSharding,
|
||||
Chunked as Chunked,
|
||||
EmapInfo as EmapInfo,
|
||||
ExecuteReplicated as ExecuteReplicated,
|
||||
Index as Index,
|
||||
@@ -27,8 +26,6 @@ from jax._src.interpreters.pxla import (
|
||||
MeshComputation as MeshComputation,
|
||||
MeshDimAssignment as MeshDimAssignment,
|
||||
MeshExecutable as MeshExecutable,
|
||||
NoSharding as NoSharding,
|
||||
OpShardingType as OpShardingType,
|
||||
OrderedDictType as OrderedDictType,
|
||||
ParallelCallableInfo as ParallelCallableInfo,
|
||||
PartitionInfo as PartitionInfo,
|
||||
@@ -37,18 +34,14 @@ from jax._src.interpreters.pxla import (
|
||||
PmapExecutable as PmapExecutable,
|
||||
PxlaResultHandler as PxlaResultHandler,
|
||||
ReplicaInfo as ReplicaInfo,
|
||||
Replicated as Replicated,
|
||||
ResultsHandler as ResultsHandler,
|
||||
SPMDBatchTrace as SPMDBatchTrace,
|
||||
ShardInfo as ShardInfo,
|
||||
ShardedAxis as ShardedAxis,
|
||||
ShardingSpec as ShardingSpec,
|
||||
TileManual as TileManual,
|
||||
TileVectorize as TileVectorize,
|
||||
TilingMethod as TilingMethod,
|
||||
UnloadedMeshExecutable as UnloadedMeshExecutable,
|
||||
UnloadedPmapExecutable as UnloadedPmapExecutable,
|
||||
Unstacked as Unstacked,
|
||||
WeakRefList as WeakRefList,
|
||||
_UNSPECIFIED as _UNSPECIFIED,
|
||||
_create_pmap_sharding_spec as _create_pmap_sharding_spec,
|
||||
@@ -76,9 +69,7 @@ from jax._src.interpreters.pxla import (
|
||||
maybe_extend_axis_env as maybe_extend_axis_env,
|
||||
mesh_sharding_specs as mesh_sharding_specs,
|
||||
multi_host_supported_collectives as multi_host_supported_collectives,
|
||||
new_mesh_sharding_specs as new_mesh_sharding_specs,
|
||||
parallel_callable as parallel_callable,
|
||||
partitioned_sharding_spec as partitioned_sharding_spec,
|
||||
reconcile_num_partitions as reconcile_num_partitions,
|
||||
replicate as replicate,
|
||||
resource_typecheck as resource_typecheck,
|
||||
@@ -88,8 +79,6 @@ from jax._src.interpreters.pxla import (
|
||||
shard_aval as shard_aval,
|
||||
shard_aval_handlers as shard_aval_handlers,
|
||||
shard_to_full_p as shard_to_full_p,
|
||||
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
|
||||
spec_to_indices as spec_to_indices,
|
||||
spmd_primitive_batchers as spmd_primitive_batchers,
|
||||
stage_parallel_callable as stage_parallel_callable,
|
||||
tile_aval_nd as tile_aval_nd,
|
||||
@@ -114,6 +103,19 @@ from jax._src.op_shardings import (
    op_sharding_to_indices as op_sharding_to_indices,
)

from jax._src.sharding_specs import (
    Chunked as Chunked,
    NoSharding as NoSharding,
    OpShardingType as OpShardingType,
    Replicated as Replicated,
    ShardedAxis as ShardedAxis,
    ShardingSpec as ShardingSpec,
    Unstacked as Unstacked,
    new_mesh_sharding_specs as new_mesh_sharding_specs,
    sharding_spec_sharding_proto as sharding_spec_sharding_proto,
    spec_to_indices as spec_to_indices,
)

# Deprecations

from jax._src.mesh import Mesh as _deprecated_Mesh
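Assuming this hunk edits the jax.interpreters.pxla re-export shim (consistent with its import style), the familiar names stay importable from there and are now plain aliases of the jax._src.sharding_specs definitions:

from jax.interpreters import pxla
from jax._src import sharding_specs

# Re-exported names are the same objects, not copies.
assert pxla.ShardingSpec is sharding_specs.ShardingSpec
assert pxla.spec_to_indices is sharding_specs.spec_to_indices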
@@ -40,10 +40,10 @@ from jax._src.lax import parallel
|
||||
from jax._src import api as src_api
|
||||
from jax import random
|
||||
from jax._src import core
|
||||
from jax._src.core import ShapedArray
|
||||
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
@@ -118,13 +118,11 @@ ignore_xmap_warning = partial(
|
||||
|
||||
def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
|
||||
devices=None, sharded_dim_size=None):
|
||||
dtype = np.int32
|
||||
aval = ShapedArray(input_shape, dtype)
|
||||
|
||||
if input_data is None:
|
||||
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
|
||||
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size)
|
||||
sharding_spec = sharding_specs.create_pmap_sharding_spec(
|
||||
input_shape, in_axes, sharded_dim_size)
|
||||
|
||||
if devices is None:
|
||||
devices = jax.devices()
|
||||
@@ -2829,7 +2827,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (4, 8)
|
||||
spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
|
||||
mesh_mapping=map(pxla.ShardedAxis, (0, 1)))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((slice(0,2), slice(0,4)),
|
||||
(slice(0,2), slice(4,8)),
|
||||
(slice(2,4), slice(0,4)),
|
||||
@@ -2839,7 +2837,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (4, 8)
|
||||
spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
|
||||
mesh_mapping=map(pxla.ShardedAxis, (1, 0)))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((slice(0,2), slice(0,4)),
|
||||
(slice(2,4), slice(0,4)),
|
||||
(slice(0,2), slice(4,8)),
|
||||
@@ -2851,7 +2849,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
mesh_mapping=(pxla.Replicated(2),
|
||||
pxla.ShardedAxis(1),
|
||||
pxla.ShardedAxis(0)))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((slice(0,2), slice(0,4)),
|
||||
(slice(2,4), slice(0,4)),
|
||||
(slice(0,2), slice(4,8)),
|
||||
@@ -2861,7 +2859,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (4, 8)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
|
||||
mesh_mapping=(pxla.ShardedAxis(0),))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((slice(0,2), slice(None)),
|
||||
(slice(2,4), slice(None))))
|
||||
|
||||
@@ -2869,14 +2867,14 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (4, 8)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
|
||||
mesh_mapping=())
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((slice(None), slice(None)),))
|
||||
|
||||
def testUnmaterializedAxis(self):
|
||||
shape = (4, 8)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(4), pxla.NoSharding()),
|
||||
mesh_mapping=(pxla.ShardedAxis(0),))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((0, slice(None)),
|
||||
(1, slice(None)),
|
||||
(2, slice(None)),
|
||||
@@ -2885,7 +2883,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (2, 2)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.Unstacked(2)),
|
||||
mesh_mapping=(pxla.ShardedAxis(0),))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((slice(None), 0),
|
||||
(slice(None), 1)))
|
||||
|
||||
@@ -2893,14 +2891,14 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (2, 8)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
|
||||
mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
tuple([(0, slice(None))] * 3 + [(1, slice(None))] * 3))
|
||||
|
||||
def testReplicationPosition2(self):
|
||||
shape = (2, 8)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
|
||||
mesh_mapping=(pxla.ShardedAxis(0), pxla.ShardedAxis(1), pxla.Replicated(3)))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((0, slice(0, 4)), (0, slice(0, 4)), (0, slice(0, 4)),
|
||||
(0, slice(4, 8)), (0, slice(4, 8)), (0, slice(4, 8)),
|
||||
(1, slice(0, 4)), (1, slice(0, 4)), (1, slice(0, 4)),
|
||||
@@ -2910,7 +2908,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (2, 8)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
|
||||
mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3), pxla.ShardedAxis(1)))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((0, slice(0, 4)), (0, slice(4, 8)),
|
||||
(0, slice(0, 4)), (0, slice(4, 8)),
|
||||
(0, slice(0, 4)), (0, slice(4, 8)),
|
||||
@@ -2922,7 +2920,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = (2, 8)
|
||||
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
|
||||
mesh_mapping=(pxla.Replicated(3), pxla.ShardedAxis(0)))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
tuple([(0, slice(None)), (1, slice(None))] * 3))
|
||||
|
||||
def testMultipleReplications(self):
|
||||
@@ -2933,7 +2931,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
pxla.ShardedAxis(0), pxla.Replicated(2),
|
||||
pxla.ShardedAxis(1)))
|
||||
self.assertEqual(
|
||||
pxla.spec_to_indices(shape, spec),
|
||||
sharding_specs.spec_to_indices(shape, spec),
|
||||
((0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)),
|
||||
(0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)),
|
||||
(1, slice(None), slice(0, 2)), (1, slice(None), slice(2, 4)),
|
||||
@@ -2943,7 +2941,7 @@ class SpecToIndicesTest(jtu.JaxTestCase):
|
||||
shape = ()
|
||||
spec = pxla.ShardingSpec(sharding=(),
|
||||
mesh_mapping=(pxla.Replicated(3),))
|
||||
self.assertEqual(pxla.spec_to_indices(shape, spec),
|
||||
self.assertEqual(sharding_specs.spec_to_indices(shape, spec),
|
||||
((), (), ()))
|
||||
|
||||
|
||||
@@ -3003,7 +3001,7 @@ class ShardArgsTest(jtu.JaxTestCase):
|
||||
mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
|
||||
])
|
||||
def testShardArgs(self, shape, spec, make_arg):
|
||||
indices = pxla.spec_to_indices(shape, spec)
|
||||
indices = sharding_specs.spec_to_indices(shape, spec)
|
||||
nshards = len(indices)
|
||||
if jax.device_count() < nshards:
|
||||
raise SkipTest
|
||||
@@ -3014,7 +3012,7 @@ class ShardArgsTest(jtu.JaxTestCase):
|
||||
sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
|
||||
else:
|
||||
sharding = jax.sharding.GSPMDSharding(
|
||||
jax.devices()[:nshards], pxla.sharding_spec_sharding_proto(spec))
|
||||
jax.devices()[:nshards],
|
||||
sharding_specs.sharding_spec_sharding_proto(spec))
|
||||
|
||||
results = pxla.shard_args(
|
||||
jax.devices()[:nshards], [indices], [sharding], [arg]
|
||||