We're switching to the new terminology to avoid confusion in cases where multiple jax processes are running on a single host, and each process has a unique process_index/host_id. This keeps aliases for the old `host_id` APIs for now, but these will eventually be removed. This was originally committed in b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in 14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test failures from renaming the local_devices argument. This change is identical except it also adds staging for the argument name change.
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of pmap and related functionality."""

# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.
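#
# As an illustrative sketch only (restating the example above in terms of the
# classes defined below; not executed anywhere in this module), that layout
# corresponds to something like:
#
#   spec = ShardingSpec(sharding=(Chunked([4]),),
#                       mesh_mapping=(ShardedAxis(0), Replicated(2)))
#   spec_to_indices((4,), spec)
#
# which yields eight indices selecting rows 0, 0, 1, 1, 2, 2, 3, 3 of the
# logical array, i.e. each shard is placed on two consecutive devices.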

import sys
from contextlib import contextmanager
from collections import defaultdict, OrderedDict
import itertools as it
import operator as op
import threading
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
                    Type, Union, Iterable, NamedTuple, TYPE_CHECKING)

from absl import logging
import numpy as np

from .._src.config import config
from .. import core
from .. import linear_util as lu
from ..abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray
from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
                         extend_name_stack, wrap_name, assert_unreachable,
                         tuple_insert, tuple_delete, distributed_debug_log)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..lib import pmap_lib
from ..tree_util import tree_flatten, tree_map
from . import batching
from . import partial_eval as pe
from . import xla
from . import ad

if sys.version_info >= (3, 9):
  OrderedDictType = OrderedDict
else:
  OrderedDictType = Dict

xops = xc.ops

unsafe_map, map = map, safe_map  # type: ignore

Index = Union[int, slice, Tuple[Union[int, slice], ...]]


# mypy cannot deal with the C++ types. An alternative is to use `# type: ignore`
if TYPE_CHECKING:
  # We cannot use `NoSharding = Any` with mypy, otherwise you get:
  # error: Cannot use isinstance() with Any type  [misc]
  class NoSharding:
    pass

  class Chunked(NamedTuple):
    chunks: List[int]

  class Unstacked(NamedTuple):
    size: int

  class ShardedAxis(NamedTuple):
    axis: int

  class Replicated(NamedTuple):
    replicas: int
else:
  # See the C++ code for comments.
  NoSharding = pmap_lib.NoSharding
  Chunked = pmap_lib.Chunked
  Unstacked = pmap_lib.Unstacked

  ShardedAxis = pmap_lib.ShardedAxis
  Replicated = pmap_lib.Replicated

_UNSHARDED_INSTANCE = NoSharding()
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
MeshDimAssignment = Union[ShardedAxis, Replicated]


class ShardingSpec:
  """Describes the sharding of an ndarray.

  Attributes:
    sharding: specifies how the array is supposed to get partitioned into
      chunks. Its length should match the rank of the array. See the docstring
      of `AvalDimSharding` for the supported partitioning schemes.
    mesh_mapping: describes an assignment of the array chunks created by
      `sharding` to a logical device mesh. The length of the tuple is equal to
      the rank of the mesh. Each mesh dimension can either get partitions of
      data varying along one of the sharded dimensions, or the data can be
      replicated. See the docstring of `MeshDimAssignment` for more information.
  """
  sharding: Tuple[AvalDimSharding, ...]
  mesh_mapping: Tuple[MeshDimAssignment, ...]

  def __init__(self,
               sharding: Iterable[AvalDimSharding],
               mesh_mapping: Iterable[MeshDimAssignment]):
    self.sharding = tuple(sharding)
    assert all(x is not None for x in self.sharding)
    self.mesh_mapping = tuple(mesh_mapping)

  @property
  def mesh_shape(self):
    sharded_axis_sizes = []
    for sharding in self.sharding:
      if isinstance(sharding, NoSharding):
        continue
      elif isinstance(sharding, Unstacked):
        sharded_axis_sizes.append(sharding.size)
      elif isinstance(sharding, Chunked):
        sharded_axis_sizes.extend(sharding.chunks)
      else:
        assert_unreachable(sharding)
    return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas
                 for a in self.mesh_mapping)

  def sharding_proto(self):
    """Converts a ShardingSpec to an OpSharding proto.

    See
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
    for details on the OpSharding proto.
    Unfortunately the semantics are not very well described in the proto spec,
    but the code here might help:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
    """
    mesh_shape = self.mesh_shape
    mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)

    sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
    replicated_maxes = []  # lists mesh axis identifiers to replicate over
    for maxis, assignment in enumerate(self.mesh_mapping):
      if isinstance(assignment, Replicated):
        replicated_maxes.append(maxis)
      elif isinstance(assignment, ShardedAxis):
        sharded_axes[assignment.axis] = maxis
      else:
        assert_unreachable(assignment)

    proto = xc.OpSharding()
    if len(replicated_maxes) == len(self.mesh_mapping):
      proto.type = xc.OpSharding.Type.REPLICATED
      return proto
    else:
      proto.type = xc.OpSharding.Type.OTHER

    mesh_permutation = []
    new_mesh_shape = []
    next_sharded_axis = 0
    for axis, sharding in enumerate(self.sharding):
      if isinstance(sharding, NoSharding):
        new_mesh_shape.append(1)  # Add a dummy mesh axis we won't be sharding over
      elif isinstance(sharding, Chunked):
        for nchunks in sharding.chunks:
          maxis = sharded_axes[next_sharded_axis]
          assert mesh_shape[maxis] == nchunks
          mesh_permutation.append(maxis)
          next_sharded_axis += 1
        new_mesh_shape.append(int(np.prod(sharding.chunks)))
      elif isinstance(sharding, Unstacked):
        raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
      else:
        assert_unreachable(sharding)

    # Create the partial sharding proto if tensor is replicated over some mesh axes
    if replicated_maxes:
      new_mesh_shape.append(-1)
      mesh_permutation.extend(replicated_maxes)
      proto.replicate_on_last_tile_dim = True

    proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
    proto.tile_assignment_dimensions = list(proto_mesh.shape)
    proto.tile_assignment_devices = list(proto_mesh.flat)
    return proto

  def indices(self, shape: Tuple[int, ...]) -> np.ndarray:
    """Returns NumPy-style indices corresponding to a sharding spec.

    Args:
      shape: The shape of the logical array being sharded.

    Returns:
      An ndarray with the same shape as the logical mesh (as derived from
      `mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
      the data array to be placed on a corresponding device. The indices can be
      ints, slice objects with step=1, or tuples of those.
    """
    assert len(shape) == len(self.sharding), (shape, self.sharding)

    axis_indices: List[Sequence[Index]] = []
    shard_indices_shape = []
    for dim, sharding in enumerate(self.sharding):
      axis_size = shape[dim]
      if isinstance(sharding, NoSharding):
        axis_indices.append([slice(None)])
        # NOTE: We don't append unsharded dimensions to shard_indices_shape here,
        # because they do not appear in the mesh mapping.
      elif isinstance(sharding, Unstacked):
        assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
        axis_indices.append(range(axis_size))
        shard_indices_shape.append(axis_size)
      elif isinstance(sharding, Chunked):
        total_chunks = int(np.prod(sharding.chunks))
        shard_size, ragged = divmod(axis_size, total_chunks)
        assert not ragged, (axis_size, total_chunks, dim)
        axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
                             for i in range(total_chunks)])
        shard_indices_shape.extend(sharding.chunks)
      else:
        assert_unreachable(sharding)

    # shard_indices is an ndarray representing the sharded axes of the logical
    # array, with each dimension having size equal to the number of shards
    # across the corresponding logical array dimension, and each element
    # containing the multi-dimensional index that is used to extract the
    # corresponding shard of the logical array.
    shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object_)
    for i, idxs in enumerate(it.product(*axis_indices)):
      shard_indices[i] = idxs
    shard_indices = shard_indices.reshape(shard_indices_shape)

    # Ensure that each sharded axis is used exactly once in the mesh mapping
    num_sharded_dim = len(shard_indices_shape)
    sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
    assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
            len(sharded_dim_perm) == num_sharded_dim)
    # Replicate/reorder the indices according to the mesh mapping
    replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
    replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm)
    perm = [next(replica_dim) if isinstance(a, Replicated) else
            len(replica_sizes) + next(sharded_dim)
            for a in self.mesh_mapping]
    return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
            .transpose(perm))

  def __eq__(self, other):
    return (self.sharding, self.mesh_mapping) == (other.sharding,
                                                  other.mesh_mapping)

  def __hash__(self):
    return hash((self.sharding, self.mesh_mapping))

  def __repr__(self):
    return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'

def spec_to_indices(shape: Tuple[int, ...],
                    spec: ShardingSpec) -> Tuple[Index, ...]:
  """Returns numpy-style indices corresponding to a sharding spec.

  Each index describes a shard of the array. The order of the indices is the
  same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
  row-major).

  Args:
    shape: The shape of the logical array being sharded.
    spec: Describes how the array is sharded and how the shards are assigned to
      the logical mesh.

  Returns:
    A tuple of length equal to the size of the mesh (inferred as the product of
    sharded dimension sizes and all replication factors). Each element is an
    int, a slice object with step=1, or a tuple thereof, to be treated as an
    index into the full logical array.
  """
  return tuple(spec.indices(shape).flat)  # type: ignore


### util

def identity(x): return x

# TODO(skye): expose PyLocalBuffers in xla_client
def shard_args(devices: Sequence[xb.xla_client.Device],
               indices: Sequence[Sequence[Index]],
               args) -> Sequence[Sequence[xb.xla_client._xla.PyLocalBuffer]]:
  """Shard each argument data array along its leading axis.

  Args:
    devices: sequence of Devices mapping replica index to a physical device.
    indices: sequence of the same length as `args` describing how each arg
      should be sharded/replicated across `devices`. Each element in `indices`
      is the same length as `devices`.
    args: a sequence of JaxTypes representing arguments to be sharded according
      to `indices` and placed on `devices`.

  Returns:
    A list of length matching args, containing lists of per-device buffers
    for each argument.
  """

  def shard_arg(a, arg):
    # The shard_arg_handlers allow an extensible set of types to be sharded, but
    # inline handling for ShardedDeviceArray as a special case for performance
    # NOTE: we compare indices instead of sharding_spec because
    # pmap_benchmark.pmap_shard_args_benchmark indicates this is faster.
    if type(arg) is ShardedDeviceArray and indices[a] == arg.indices:
      return [
          buf if buf.device() == d else buf.copy_to_device(d)
          for d, buf in zip(devices, arg.device_buffers)
      ]
    else:
      arg = xla.canonicalize_dtype(arg)
      return shard_arg_handlers[type(arg)](arg, devices, indices[a])

  return [shard_arg(a, arg) for a, arg in enumerate(args)]


shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
shard_arg_handlers[core.Unit] = \
    lambda x, devices, _: device_put(core.unit, devices, replicate=True)
def _shard_array(x, devices, indices):
  return device_put([x[i] for i in indices], devices)
for _t in array_types:
  shard_arg_handlers[_t] = _shard_array

def _shard_device_array(x, devices, indices):
  start_indices, limit_indices, removed_dims = map(tuple, unzip3(
      _as_slice_indices(x, idx) for idx in indices))
  shards = x._multi_slice(start_indices, limit_indices, removed_dims)
  return device_put(shards, devices)
shard_arg_handlers[xla._DeviceArray] = _shard_device_array
shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array


# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def _as_slice_indices(arr: xla.DeviceArrayProtocol, idx: Index) -> Tuple[
    Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
  """Returns start_indices, limit_indices, removed_dims"""
  start_indices = [0] * arr.ndim
  limit_indices = list(arr.shape)
  removed_dims = []

  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
  for dim, sub_idx in enumerate(tuple_idx):
    if isinstance(sub_idx, int):
      start_indices[dim] = sub_idx
      limit_indices[dim] = sub_idx + 1
      removed_dims.append(dim)
    elif sub_idx == slice(None):
      continue
    else:
      assert isinstance(sub_idx, slice), sub_idx
      assert isinstance(sub_idx.start, int), sub_idx
      assert isinstance(sub_idx.stop, int), sub_idx
      start_indices[dim] = sub_idx.start
      limit_indices[dim] = sub_idx.stop

  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)  # type: ignore
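
# Illustrative example (not part of the module's API): for an array of shape
# (8, 16), _as_slice_indices(arr, (2, slice(4, 8))) returns
# start_indices=(2, 4), limit_indices=(3, 8), removed_dims=(0,), i.e. the
# integer index drops its dimension while the slice keeps it.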


def shard_aval(size, axis: int, aval):
  try:
    return shard_aval_handlers[type(aval)](size, axis, aval)
  except KeyError as err:
    raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
shard_aval_handlers[core.AbstractUnit] = lambda size, axis, x: x
def _shard_abstract_array(size, axis: int, x):
  try:
    if x.shape[axis] != size:
      raise ValueError(f"Axis size {size} does not match dimension {axis} of "
                       f"shape {x.shape}")
  except IndexError:
    raise ValueError(f"Cannot split a {x.ndim}D value along axis {axis}") from None
  return x.update(shape=tuple_delete(x.shape, axis))
shard_aval_handlers[ShapedArray] = _shard_abstract_array

# TODO(skye): expose PyLocalBuffers in xla_client
def aval_to_result_handler(sharding_spec: Optional[ShardingSpec],
                           indices: Optional[Tuple[Index]],
                           aval: core.AbstractValue) -> Callable[
                               [List[xb.xla_client._xla.PyLocalBuffer]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  Args:
    sharding_spec: indicates how the output is sharded across devices, or None
      for non-array avals.
    indices: the pre-computed result of spec_to_indices, or None for non-array
      avals.
    aval: the output AbstractValue.

  Returns:
    A function for handling the PyLocalBuffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user, e.g. a ShardedDeviceArray.
  """
  try:
    return pxla_result_handlers[type(aval)](sharding_spec, indices, aval)
  except KeyError as err:
    raise TypeError("No pxla_result_handler for type: {}".format(type(aval))
                    ) from err

PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client._xla.PyLocalBuffer]], Any]]
pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
def array_result_handler(sharding_spec, indices, aval: ShapedArray):
  return lambda bufs: ShardedDeviceArray(aval, sharding_spec, bufs, indices)
pxla_result_handlers[ShapedArray] = array_result_handler
pxla_result_handlers[ConcreteArray] = array_result_handler


### lazy device-memory persistence and result handling

class ShardedDeviceArray(xla.DeviceArray):  # type: ignore
  """A ShardedDeviceArray is an ndarray sharded across devices.

  The purpose of a ShardedDeviceArray is to reduce the number of transfers when
  executing replicated computations, by allowing results to persist on the
  devices that produced them. That way dispatching a similarly replicated
  computation that consumes the same sharded memory layout does not incur any
  transfers.

  A ShardedDeviceArray represents one logical ndarray value, and simulates the
  behavior of an ndarray so that it can be treated by user code as an ndarray;
  that is, it is only an optimization to reduce transfers.

  Attributes:
    aval: A ShapedArray indicating the shape and dtype of this array.
    sharding_spec: describes how this array is sharded across `device_buffers`.
    device_buffers: the buffers containing the data for this array. Each buffer
      is the same shape and on a different device. Buffers are in row-major
      order, with replication treated as an extra innermost dimension.
    indices: the result of spec_to_indices(sharding_spec). Can optionally be
      precomputed for efficiency. A list the same length as
      `device_buffers`. Each index indicates what portion of the full array is
      stored in the corresponding device buffer, i.e. `array[indices[i]] ==
      device_buffers[i].to_py()`.
  """
  __slots__ = [
      "aval", "device_buffers", "sharding_spec", "indices",
      "_one_replica_buffer_indices", "_npy_value"
  ]

  # TODO(skye): expose PyLocalBuffers in xla_client
  def __init__(self,
               aval: ShapedArray,
               sharding_spec,  # TODO(skye): add type annotation back, see below
               device_buffers: List[xb.xla_client._xla.PyLocalBuffer] = None,
               indices: Optional[Tuple[Index, ...]] = None):
    xla.DeviceArray.__init__(self)

    # TODO(skye): this is temporary staging while we switch users over to
    # providing sharding_spec. It assumes that any pre-existing callers are
    # creating pmap-style ShardedDeviceArrays over the first dimension.
    if device_buffers is None:
      device_buffers = sharding_spec
      sharded_aval = aval.update(shape=aval.shape[1:])
      sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
                                          1, None, sharded_aval, 0)

    # TODO(skye): assert invariants. Keep performance in mind though.
    if indices is None:
      indices = spec_to_indices(aval.shape, sharding_spec)
    self.aval = aval
    self.device_buffers = device_buffers
    self.sharding_spec = sharding_spec
    self.indices = indices
    self._npy_value = None
    self._one_replica_buffer_indices = None
    if config.jax_enable_checks:
      assert type(aval) is ShapedArray

  @property
  def one_replica_buffer_indices(self):
    """Indices of buffers containing one complete copy of the array data."""
    if self._one_replica_buffer_indices is None:
      one_replica_indices = []
      seen_index_hashes = set()
      for i, index in enumerate(self.indices):
        hashed_index = _hashable_index(index)
        if hashed_index not in seen_index_hashes:
          one_replica_indices.append(i)
          seen_index_hashes.add(hashed_index)
      self._one_replica_buffer_indices = one_replica_indices
    return self._one_replica_buffer_indices

  @property
  def shape(self):
    return self.aval.shape

  @property
  def dtype(self):
    return self.aval.dtype

  @property
  def size(self):
    return prod(self.aval.shape)

  @property
  def ndim(self):
    return len(self.aval.shape)

  def copy_to_host_async(self):
    for buffer_index in self.one_replica_buffer_indices:
      self.device_buffers[buffer_index].copy_to_host_async()

  def delete(self):
    for buf in self.device_buffers:
      buf.delete()
    self.device_buffers = None
    self._npy_value = None

  def _check_if_deleted(self):
    if self.device_buffers is None:
      raise ValueError("ShardedDeviceArray has been deleted.")

  def block_until_ready(self):
    self._check_if_deleted()
    for buf in self.device_buffers:
      buf.block_host_until_ready()
    return self

  @property
  def _value(self):
    if self._npy_value is None:
      self.copy_to_host_async()
      npy_value = np.empty(self.aval.shape, self.aval.dtype)
      for i in self.one_replica_buffer_indices:
        npy_value[self.indices[i]] = self.device_buffers[i].to_py()
      self._npy_value = npy_value
    return self._npy_value

  def __getitem__(self, idx):
    if not isinstance(idx, tuple):
      cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
    else:
      cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
    try:
      buf_idx = self.indices.index(cidx)
    except ValueError:
      # NOTE: Slow path, this will materialize the sharded array on a single
      # device and use XLA's Gather to index into the resulting array.
      return xla.DeviceArray.__getitem__(self, idx)
    else:
      buf = self.device_buffers[buf_idx]
      aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
      return xla.make_device_array(aval, None, buf)


def _hashable_index(idx):
  return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
                  idx)

# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices):
  candidates = defaultdict(list)
  for buf, idx in safe_zip(x.device_buffers, x.indices):
    candidates[_hashable_index(idx)].append(buf)

  bufs = []
  for idx, device in safe_zip(indices, devices):
    # Look up all buffers that contain the correct slice of the logical array.
    candidates_list = candidates[_hashable_index(idx)]
    if not candidates_list:
      # This array isn't sharded correctly. Reshard it via host roundtrip.
      # TODO(skye): more efficient reshard?
      return shard_arg_handlers[type(x._value)](x._value, devices, indices)
    # Try to find a candidate buffer already on the correct device,
    # otherwise copy one of them.
    for buf in candidates_list:
      if buf.device() == device:
        bufs.append(buf)
        break
    else:
      bufs.append(buf.copy_to_device(device))
  return bufs
shard_arg_handlers[ShardedDeviceArray] = _shard_sharded_device_array_slow_path

def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
  return xb.constant(c, np.asarray(val), canonicalize_types=canonicalize_types)
xb.register_constant_handler(ShardedDeviceArray, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[ShardedDeviceArray] = ConcreteArray
xla.device_put_handlers[ShardedDeviceArray] = xla._device_put_array
xla.pytype_aval_mappings[ShardedDeviceArray] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity


### the xla_pmap primitive and its rules are comparable to xla_call in xla.py

def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size,
                  global_axis_size, devices, name, in_axes, out_axes_thunk,
                  donated_invars, global_arg_shapes):
  abstract_args = unsafe_map(xla.abstractify, args)
  compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
                                   global_axis_size, devices, name,
                                   in_axes, out_axes_thunk,
                                   donated_invars, global_arg_shapes,
                                   *abstract_args)
  # Don't re-abstractify args unless logging is enabled for performance.
  if config.jax_distributed_debug:
    distributed_debug_log(("Running pmapped function", name),
                          ("python function", fun.f),
                          ("devices", devices),
                          ("abstract args", map(xla.abstractify, args)))
  return compiled_fun(*args)

@lu.cache
def parallel_callable(fun: lu.WrappedFun,
                      backend_name: Optional[str],
                      axis_name,
                      axis_size: int,
                      global_axis_size: Optional[int],
                      devices: Optional[Sequence[Any]],
                      name: str,
                      in_axes: Iterable[Optional[int]],
                      out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                      donated_invars: Iterable[bool],
                      global_arg_shapes,
                      *avals):
  if devices is not None and len(devices) == 0:
    raise ValueError("'devices' argument to pmap must be non-empty, or None.")

  # Determine global_axis_size for use in AxisEnv.
  # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
  #   raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
  if (xb.process_count() == 1 and global_axis_size is not None and
      global_axis_size != axis_size):
    raise ValueError(
        f"Specified axis_size {global_axis_size} doesn't match received "
        f"axis_size {axis_size}.")

  must_run_on_all_devices = False
  no_nested_sharding = False
  if global_axis_size is None:
    if xb.process_count() == 1:
      global_axis_size = axis_size
    elif devices:
      # This allows each host in a multi-host pmap to run on a different number
      # of devices, but precludes nested sharding (i.e. inner pmaps or
      # sharded_jits).
      global_axis_size = len(devices)
      no_nested_sharding = True
    else:
      # This assumes all hosts run on the same number of devices. We make sure
      # this assumption is true by requiring that the pmap is run on all devices
      # (and making the further assumption that each host has the same number of
      # devices). Nested sharding is ok in this case.
      global_axis_size = axis_size * xb.process_count()
      assert all(len(xb.local_devices(process_index)) == xb.local_device_count()
                 for process_index in range(xb.process_count()))
      must_run_on_all_devices = True

  if devices:
    local_devices = [d for d in devices if d.process_index == xb.process_index()]
    assert len(local_devices) > 0
  else:
    local_devices = None  # type: ignore

  sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
                        for axis, aval in safe_zip(in_axes, avals))
  if any(s is not None for s in global_arg_shapes):
    # TODO(skye): we could take this branch unconditionally if we handled
    # grad of global_arg_shapes correctly.
    global_sharded_avals = [
        aval.update(shape=shape) if shape is not None else aval
        for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
  else:
    global_sharded_avals = sharded_avals  # type: ignore
  logging.vlog(2, "sharded_avals: %s", sharded_avals)
  logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)

  with core.extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap")
  jaxpr = xla.apply_outfeed_rewriter(jaxpr)

  out_axes = out_axes_thunk()
  assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))

  # TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
  # for now raise an error
  if devices is not None:
    is_multi_host_pmap = len(local_devices) != len(devices)
  else:
    is_multi_host_pmap = xb.process_count() > 1
  if is_multi_host_pmap:
    check_multihost_collective_allowlist(jaxpr)

  # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
  jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
  num_local_replicas = axis_size * jaxpr_replicas
  num_global_replicas = global_axis_size * jaxpr_replicas

  (arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts,
   local_num_partitions) = _find_partitions(jaxpr)

  if local_num_partitions is None:
    local_num_partitions = num_partitions

  if local_arg_parts is None:
    local_arg_parts = arg_parts
  if local_out_parts is None:
    local_out_parts = out_parts

  logging.vlog(2, "num_replicas: %d  num_local_replicas: %d",
               num_global_replicas, num_local_replicas)
  logging.vlog(2, "num_partitions: %d  local_num_partitions: %d",
               num_partitions, local_num_partitions)
  logging.vlog(2, "arg_parts: %s", arg_parts)
  logging.vlog(2, "local_arg_parts: %s", local_arg_parts)
  logging.vlog(2, "out_parts: %s", out_parts)
  logging.vlog(2, "local_out_parts: %s", local_out_parts)
  logging.vlog(2, "devices: %s", devices)
  logging.vlog(2, "local_devices: %s", local_devices)

  num_local_shards = num_local_replicas * local_num_partitions
  num_global_shards = num_global_replicas * num_partitions

  if (xb.process_count() > 1 and must_run_on_all_devices and
      num_local_shards != xb.local_device_count()):
    if num_local_shards == axis_size:
      raise ValueError(
          f"On multi-host platforms, the input to pmapped functions must have "
          f"leading axis size equal to the number of local devices if no "
          f"`devices` argument is specified. Got axis_size={axis_size}, "
          f"num_local_devices={xb.local_device_count()}")
    else:
      raise ValueError(
          f"On multi-host platforms, pmapped functions must run across all "
          f"devices, i.e. num_replicas * num_partitions should equal the "
          f"number of local devices. Got num_replicas={num_local_replicas}, "
          f"num_partitions={num_partitions}, and "
          f"num_local_devices={xb.local_device_count()}")

  if no_nested_sharding and (jaxpr_replicas > 1 or num_partitions > 1):
    raise ValueError(
        f"On multi-host platforms, pmapped functions that both have `devices` "
        f"specified and contain an inner_pmap or sharded_jit must specify an "
        f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
        f"{jaxpr_replicas} and nested_partitions={num_partitions}")

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority,
              f"Compiling {fun.__name__} ({id(fun)}) for {num_global_shards} "
              f"devices with args {avals}. (num_replicas={num_global_replicas}"
              f" num_partitions={num_partitions})")

  axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,))

  tuple_args = len(global_sharded_avals) > 100  # pass long arg lists as tuple for TPU

  c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
  xla_consts = map(partial(xb.constant, c), consts)
  replicated_args = [axis is None for axis in in_axes]
  xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
                                                    replicated=replicated_args,
                                                    partitions=arg_parts,
                                                    donated_invars=donated_invars)
  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
    out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend_name, axis_env, xla_consts,
                                  extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
  build_out_tuple = partial(xops.Tuple, c, out_nodes)
  if out_parts is not None:
    out_tuple = xb.with_sharding(c, out_parts, build_out_tuple)
  else:
    out_tuple = build_out_tuple()
  backend = xb.get_backend(backend_name)
  if backend.platform in ("gpu", "tpu"):
    donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
  built = c.Build(out_tuple)

  if devices is None:
    if num_global_shards > xb.device_count(backend):
      msg = ("compiling computation that requires {} logical devices, but only {} XLA "
             "devices are available (num_replicas={}, num_partitions={})")
      raise ValueError(msg.format(num_global_shards, xb.device_count(backend),
                                  num_global_replicas, num_partitions))

    # On a single host, we use the platform's default device assignment to
    # potentially take advantage of device locality. On multiple hosts, the
    # default device assignment may interleave different hosts' replicas,
    # violating pmap's semantics where data is sharded across replicas in
    # row-major order. Instead, manually create a device assignment that ensures
    # each host is responsible for a contiguous set of replicas.
    if num_global_shards > num_local_shards:
      # TODO(skye): use a locality-aware assignment that satisfies the above
      # constraint.
      devices = [d for process_index in range(xb.process_count())
                 for d in xb.local_devices(process_index)]
    else:
      devices = xb.get_backend(backend).get_default_device_assignment(
          num_global_replicas, num_partitions)
  else:
    if num_local_shards != len(local_devices):
      local_devices_str = ", ".join(map(str, local_devices))
      if num_local_shards == axis_size:
        raise ValueError(
            f"Leading axis size of input to pmapped function must equal the "
            f"number of local devices passed to pmap. Got axis_size="
            f"{axis_size}, num_local_devices={len(local_devices)}.\n(Local "
            f"devices available to pmap: {local_devices_str})")
      else:
        raise ValueError(
            f"pmapped function requires {num_local_shards} local devices to "
            f"run due to nested pmapped or other parallel functions, but only "
            f"{len(local_devices)} are available.\n(outer axis size: "
            f"{axis_size}, local devices available to pmap: "
            f"{local_devices_str})")
    if num_global_shards != len(devices):
      raise ValueError("compiling computation that creates %s shards, "
                       "but %s devices were specified" %
                       (num_global_shards, len(devices)))

  # 'devices' may be 1D or 2D at this point (e.g.
  # get_default_device_assignment() returns 2D assignment, caller may have
  # provided 1D list of devices).
  device_assignment = tree_map(lambda d: d.id, devices)
  # Convert to 2D in case it's 1D and we have > 1 partitions.
  device_assignment = np.array(device_assignment).reshape(
      (num_global_replicas, num_partitions))
  # TODO(b/162356737): Enabling SPMD partitioning causes issues with some
  # non-partitioned workloads, so disable unless needed.
  use_spmd_partitioning = num_partitions > 1
  compile_options = xb.get_compile_options(
      num_replicas=num_global_replicas,
      num_partitions=num_partitions,
      device_assignment=device_assignment,
      use_spmd_partitioning=use_spmd_partitioning,
  )
  compile_options.parameter_is_tupled_arguments = tuple_args
  compiled = xla.backend_compile(backend, built, compile_options)

  local_arg_parts_ = local_arg_parts or [None] * len(avals)
  input_sharding_specs = [
      _pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
                          parts, aval, in_axis)
      if aval is not core.abstract_unit else None
      for aval, parts, in_axis in safe_zip(sharded_avals, local_arg_parts_, in_axes)]
  input_indices = [spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in safe_zip(avals, input_sharding_specs)]
  handle_args = partial(shard_args, compiled.local_devices(), input_indices)
  nouts = len(out_sharded_avals)
  if out_parts is None:
    out_parts = (None,) * nouts
  if local_out_parts is None:
    local_out_parts = (None,) * nouts

  local_out_avals = [get_local_aval(aval, parts, lparts)
                     for aval, parts, lparts
                     in safe_zip(out_sharded_avals, out_parts, local_out_parts)]
  local_unmapped_avals = [core.unmapped_aval(axis_size, out_axis, aval)
                          if out_axis is not None else aval
                          for aval, out_axis in safe_zip(local_out_avals, out_axes)]

  out_specs = [_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
                                   parts, aval, out_axis)
               if aval is not core.abstract_unit else None
               for parts, aval, out_axis in safe_zip(local_out_parts, local_out_avals, out_axes)]
  handle_outs = avals_to_results_handler(
      num_local_replicas, local_num_partitions, out_specs, local_unmapped_avals)

  if hasattr(backend, "wrap_execute_replicated"):
    return backend.wrap_execute_replicated(compiled, compiled.local_devices(),
                                           input_indices, input_sharding_specs,
                                           handle_outs)
  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)


multi_host_supported_collectives: Set[core.Primitive] = set()


def check_multihost_collective_allowlist(jaxpr):
  used_collectives = set(xla.jaxpr_collectives(jaxpr))
  if not used_collectives.issubset(multi_host_supported_collectives):
    bad_collectives = used_collectives - multi_host_supported_collectives
    msg = "using collectives that aren't supported for multi-host: {}"
    raise TypeError(msg.format(", ".join(map(str, bad_collectives))))


PartitionsOrReplicated = Optional[Tuple[int, ...]]

def _find_partitions(jaxpr) -> Tuple[
    Optional[Tuple[PartitionsOrReplicated, ...]],
    Optional[Tuple[PartitionsOrReplicated, ...]],
    int,
    Optional[Tuple[PartitionsOrReplicated, ...]],
    Optional[Tuple[PartitionsOrReplicated, ...]],
    Optional[int]]:
  """Returns (in_partitions, out_partitions, num_partitions, local_in_parts,
  local_out_parts, local_num_partitions).
  """
  for eqn in jaxpr.eqns:
    if eqn.primitive.name == "sharded_call":
      if len(jaxpr.eqns) > 1:
        raise NotImplementedError(
            "pmap of sharded_jit + non-sharded operations not yet implemented.")
      num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"],
                                                eqn.params["nparts"])
      return (eqn.params["in_parts"],
              eqn.params["out_parts_thunk"](),
              num_partitions,
              eqn.params["local_in_parts"],
              eqn.params["local_out_parts_thunk"](),
              eqn.params["local_nparts"])
  return None, None, 1, None, None, None

def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]):
  """Returns the total number of partitions to use.

  Validates that any inner partitioning matches outer_num_parts if provided, and
  returns the number of partitions to use based on outer_num_parts and any inner
  partitioning.
  """
  inner_num_parts = _inner_partitions(jaxpr, outer_num_parts)
  if outer_num_parts is None and inner_num_parts is None:
    # No partitions specified anywhere, everything is replicated.
    return 1
  if outer_num_parts is None:
    return inner_num_parts
  return outer_num_parts


def _inner_partitions(jaxpr, expected_num_parts: Optional[int]):
  """Returns the total number of partitions from PartitionSpecs inside `jaxpr`.

  Also validates that this number matches `expected_num_parts` if provided.
  """
  for eqn in jaxpr.eqns:
    if eqn.primitive.name in ["sharding_constraint", "infeed"]:
      parts = eqn.params["partitions"]
      nparts = get_num_partitions(parts)
      if expected_num_parts is None:
        expected_num_parts = nparts
      elif nparts is not None and nparts != expected_num_parts:
        # TODO(skye): raise this error as we trace the jaxpr
        raise ValueError(
            f"with_sharding_constraint with partitions={parts} "
            f"(total partitions: {nparts}) doesn't match expected number of "
            f"partitions: {expected_num_parts}. If these partitions look "
            f"right, check outer sharded_jit and/or other "
            f"with_sharding_constraint calls.")
    else:
      for subjaxpr in core.jaxprs_in_params(eqn.params):
        expected_num_parts = _inner_partitions(subjaxpr, expected_num_parts)
  return expected_num_parts


def get_num_partitions(*partitions):
  partition_specs = tree_flatten(partitions)[0]
  if len(partition_specs) == 0:
    # Everything is specified as replicated (all Nones).
    return None
  num_partitions_set = {np.prod(spec) for spec in partition_specs}
  if len(num_partitions_set) > 1:
    raise ValueError(
        f"All partition specs must use the same number of total partitions, "
        f"got {partitions}, with distinct number of partitions "
        f"{num_partitions_set} (the total number of partitions is the product "
        f"of a partition spec)")
  assert len(num_partitions_set) == 1
  return num_partitions_set.pop()


def get_global_aval(local_aval, global_parts: PartitionsOrReplicated,
                    local_parts: PartitionsOrReplicated):
  if local_aval is core.abstract_unit:
    return local_aval
  if global_parts is None:
    return local_aval
  assert local_parts is not None
  global_shape = [dim * _safe_div(ngparts, nlparts)
                  for dim, ngparts, nlparts
                  in safe_zip(local_aval.shape, global_parts, local_parts)]
  return local_aval.update(shape=global_shape)


def get_local_aval(global_aval, global_parts: PartitionsOrReplicated,
                   local_parts: PartitionsOrReplicated):
  if global_aval is core.abstract_unit:
    return global_aval
  if global_parts is None:
    return global_aval
  assert local_parts is not None
  local_shape = [_safe_div(dim, _safe_div(ngparts, nlparts))
                 for dim, ngparts, nlparts
                 in safe_zip(global_aval.shape, global_parts, local_parts)]
  return global_aval.update(shape=local_shape)


def _safe_div(x, y):
  result, ragged = divmod(x, y)
  assert not ragged, f"{x} % {y} != 0"
  return result
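
# Illustrative example (an assumption-laden sketch, not exercised here): with
# global_parts=(4, 1) and local_parts=(2, 1), a local aval of shape (8, 16)
# maps to a global aval of shape (16, 16), since each entry of the global shape
# is the local dimension times ngparts // nlparts; get_local_aval inverts this.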


class ResultsHandler:
  __slots__ = ("handlers", "out_specs", "out_indices", "unmapped_local_out_avals")

  def __init__(self, handlers, out_specs, out_indices, unmapped_local_out_avals):
    self.out_specs = out_specs
    self.out_indices = out_indices
    self.handlers = handlers
    self.unmapped_local_out_avals = unmapped_local_out_avals

  def __call__(self, out_bufs):
    return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]

def avals_to_results_handler(nrep, npart, out_specs, unmapped_local_out_avals):
  out_indices = [spec_to_indices(aval.shape, spec)
                 if aval is not core.abstract_unit else None
                 for aval, spec in safe_zip(unmapped_local_out_avals, out_specs)]  # pytype: disable=attribute-error
  handlers = [aval_to_result_handler(spec, idcs, aval)
              for spec, idcs, aval in safe_zip(out_specs, out_indices, unmapped_local_out_avals)]

  return ResultsHandler(handlers, out_specs, out_indices, unmapped_local_out_avals)

def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
  """Replicates ``val`` across multiple devices.

  Args:
    val: the value to be replicated.
    axis_size: the length of the output, i.e. the logical number of replicas to
      create. Usually equal to `nrep`, but in the case of nested pmaps, `nrep` may
      be a multiple of `axis_size`.
    nrep: the number of replicas to create. If ``devices`` is set, must be equal
      to ``len(devices)``.
    devices: the devices to replicate across. If None, ``nrep`` will be used to
      generate a default device assignment.
    backend: string specifying which backend to use.
    in_axis: axis along which the value is to be replicated.

  Returns:
    A ShardedDeviceArray of length `axis_size` where each shard is equal to
    ``val``.
  """
  device_count = (len(devices) if devices else xb.local_device_count(backend))
  if nrep > device_count:
    msg = ("Cannot replicate across %d replicas because only %d local devices "
           "are available." % (nrep, device_count))
    if devices:
      msg += (" (local devices = %s)"
              % ", ".join(map(str, devices)) if devices else str(None))
    raise ValueError(msg)

  if devices is None:
    assert nrep is not None
    # TODO(skye): use different device assignment on multihost
    devices = xb.get_backend(backend).get_default_device_assignment(nrep)
  assert nrep == len(devices)

  aval = xla.abstractify(val)  # type: ShapedArray
  if in_axis is not None:
    replicated_aval = aval.update(shape=(axis_size,) + aval.shape)
  else:
    replicated_aval = aval
  # TODO(skye): figure out how partitioning should work here
  sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis)
  device_buffers = device_put(val, devices, replicate=True)
  return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)

def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, map_axis: Optional[int]):
  """Sharding spec for arguments or results of a pmap.
  Args:
    nrep: number of local XLA replicas (product of local axis sizes)
    axis_size: local axis size for outer pmap
    npart: total number of XLA partitions (required by sharded_jit calls)
    parts: the partitioning of the value or None
    sharded_aval: the aval of the value inside the outer pmap, an instance of
      a ShapedArray.
    map_axis: the axis along which the value is mapped in the outer pmap
  Returns:
    A ShardingSpec.
  """
  assert isinstance(sharded_aval, ShapedArray), sharded_aval
  replication_factor, ragged = divmod(nrep, axis_size)
  assert not ragged
  # get the sharding spec from inner sharded_jits as if we weren't in a pmap
  pspec = partitioned_sharding_spec(npart, parts, sharded_aval)
  maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
  if map_axis is not None:
    sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
    def shift_sharded_axis(a: MeshDimAssignment):
      if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
        return ShardedAxis(a.axis + 1)
      return a
    # replication_factor represents the product of inner pmaps, so it goes
    # after the outer pmapped axis at index 0
    return ShardingSpec(
      sharding=tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)),
      mesh_mapping=it.chain([ShardedAxis(sharded_in_axis)],
                            maybe_replicate,
                            map(shift_sharded_axis, pspec.mesh_mapping)))
  else:
    return ShardingSpec(
      sharding=pspec.sharding,
      mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
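
# Illustrative example (a sketch assuming a simple, un-partitioned pmap
# argument mapped along axis 0): with nrep=8, axis_size=8, npart=1, parts=None,
# a sharded aval of shape (3,) and map_axis=0, this returns roughly
#   ShardingSpec(sharding=(Unstacked(8), NoSharding()),
#                mesh_mapping=(ShardedAxis(0),))
# i.e. the leading axis is unstacked across the 8 devices and the remaining
# axis is left unsharded.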

def partitioned_sharding_spec(num_partitions: int,
                              partitions: Optional[Sequence[int]],
                              aval) -> ShardingSpec:
  if partitions is None:
    maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
    return ShardingSpec(
        sharding=[_UNSHARDED_INSTANCE] * len(aval.shape),
        mesh_mapping=maybe_replicate)
  else:
    assert len(partitions) == len(aval.shape)
    return ShardingSpec(
        # Chunked expects a list of integers
        sharding=map(Chunked, [[x] for x in partitions]),
        mesh_mapping=map(ShardedAxis, range(len(partitions))))


def execute_replicated(compiled, backend, in_handler, out_handler, *args):
  input_bufs = in_handler(args)
  out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
  if xla.needs_check_special():
    for bufs in out_bufs:
      xla.check_special("parallel computation", bufs)
  return out_handler(out_bufs)


xla_pmap_p = core.MapPrimitive('xla_pmap')
xla_pmap = xla_pmap_p.bind
xla_pmap_p.def_impl(xla_pmap_impl)

# Set param update handlers to update `donated_invars` just like xla_call_p
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
ad.call_transpose_param_updaters[xla_pmap_p] = \
    ad.call_transpose_param_updaters[xla.xla_call_p]

def _pmap_translation_rule(c, axis_env,
                           in_nodes, name_stack, axis_name, axis_size,
                           global_axis_size, devices, name,
                           call_jaxpr, *, backend=None, in_axes, out_axes,
                           donated_invars, global_arg_shapes):
  del donated_invars  # Unused.
  # We in-line here rather than generating a Call HLO as in the xla_call
  # translation rule just because the extra tuple stuff is a pain.
  if axis_env.names and devices is not None:
    raise ValueError("Nested pmap with explicit devices argument.")
  if global_axis_size is None:
    global_axis_size = axis_size
  new_env = xla.extend_axis_env(axis_env, axis_name, global_axis_size)
  # Shard the in_nodes that are mapped
  in_avals = [v.aval for v in call_jaxpr.invars]
  in_nodes_sharded = (
    _xla_shard(c, aval, new_env, in_node, in_axis) if in_axis is not None else in_node
    for aval, in_node, in_axis in safe_zip(in_avals, in_nodes, in_axes))

  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
    sharded_outs = xla.jaxpr_subcomp(
        c, call_jaxpr, backend, new_env, (),
        extend_name_stack(name_stack, wrap_name(name, 'pmap')), *in_nodes_sharded)
  out_avals = [v.aval for v in call_jaxpr.outvars]
  outs = [_xla_unshard(c, aval, new_env, out_axis, shard, backend=backend)
          for aval, out_axis, shard in safe_zip(out_avals, out_axes, sharded_outs)]
  return xops.Tuple(c, outs)

xla.call_translations[xla_pmap_p] = _pmap_translation_rule
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)

def _xla_shard(c, aval, axis_env, x, in_axis):
  if aval is core.abstract_unit:
    return x
  elif aval is core.abstract_token:
    return x
  elif isinstance(aval, ShapedArray):
    dims = list(c.get_shape(x).dimensions())
    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
    idxs = [zero] * (len(dims) - 1)
    idxs.insert(in_axis, _unravel_index(c, axis_env))
    dims_unsqueezed = dims.copy()
    dims_unsqueezed[in_axis] = 1
    dims_squeezed = dims.copy()
    dims_squeezed.pop(in_axis)
    return xops.Reshape(xops.DynamicSlice(x, idxs, dims_unsqueezed), dims_squeezed)
  else:
    raise TypeError((aval, c.get_shape(x)))

# TODO(b/110096942): more efficient gather
def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
  if aval is core.abstract_unit:
    return x
  elif aval is core.abstract_token:
    return x
  elif isinstance(aval, ShapedArray):
    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    convert_bool = (np.issubdtype(aval.dtype, np.bool_)
                    and xb.get_backend(backend).platform in ('cpu', 'gpu'))
    if convert_bool:
      x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))

    xla_shape = c.get_shape(x)
    dims = list(xla_shape.dimensions())
    padded = xops.Broadcast(xb.constant(c, np.array(0, xla_shape.numpy_dtype())),
                            [axis_env.sizes[-1]] + dims)
    zero = xb.constant(c, np.zeros((), dtype=np.uint32))
    idxs = [_unravel_index(c, axis_env)] + [zero] * len(dims)
    padded = xops.DynamicUpdateSlice(padded, xops.Reshape(x, [1] + dims), idxs)
    replica_groups_protos = xc.make_replica_groups(
        xla.axis_groups(axis_env, axis_env.names[-1]))
    out = xops.CrossReplicaSum(padded, replica_groups_protos)
    if out_axis != 0:
      # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
      perm = list(range(1, len(dims)))
      perm.insert(out_axis, 0)
      out = xops.Transpose(out, perm)

    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    if convert_bool:
      nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
      out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
    return out
  else:
    raise TypeError((aval, c.get_shape(x)))

def _unravel_index(c, axis_env):
  div = xb.constant(c, np.array(axis_env.nreps // prod(axis_env.sizes), np.uint32))
  mod = xb.constant(c, np.array(axis_env.sizes[-1], np.uint32))
  return xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)

# ------------------- xmap -------------------

MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.

Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.

For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
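
# Illustrative example: the mapping from the docstring above could be built as
#   array_mapping: ArrayMapping = OrderedDict([('x', 1), ('y', 1)])
# which, for a mesh of shape {'x': 2, 'y': 3}, splits positional axis 1 of the
# value into 2 * 3 = 6 chunks, with 'y' varying fastest.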

class Mesh:
  __slots__ = ('devices', 'axis_names', '_hash')

  def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
    assert devices.ndim == len(axis_names)
    # TODO: Make sure that devices are unique? At least with the quick and
    # dirty check that the array size is not larger than the number of
    # available devices?
    self.devices = devices.copy()
    self.devices.flags.writeable = False
    self.axis_names = tuple(axis_names)

  def __eq__(self, other):
    if not isinstance(other, Mesh):
      return False
    return (self.axis_names == other.axis_names and
            np.array_equal(self.devices, other.devices))

  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash((self.axis_names, tuple(self.devices.flat)))
    return self._hash

  def __setattr__(self, name, value):
    if hasattr(self, name):
      raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
    super().__setattr__(name, value)

  @property
  def shape(self):
    return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape))

  @property
  def size(self):
    return np.prod(list(self.shape.values()))

  # TODO: This is pretty expensive to compute. Cache this on the mesh object?
  @property
  def local_mesh(self):
    if not self.devices.ndim:
      return self
    process_index = xb.process_index()
    is_local_device = np.vectorize(
        lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
    subcube_indices = []
    # We take the smallest slice of each dimension that doesn't skip any local device.
    for axis in range(self.devices.ndim):
      other_axes = tuple_delete(tuple(range(self.devices.ndim)), axis)
      # NOTE: This re-reduces over many axes multiple times, so we could definitely
      # optimize it, but I hope it won't be a bottleneck anytime soon.
      local_slices = is_local_device.any(other_axes, keepdims=False)
      nonzero_indices = np.flatnonzero(local_slices)
      start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
      subcube_indices.append(slice(start, end + 1))
    subcube_indices = tuple(subcube_indices)
    # The assertion below only holds if the local devices form a subcube of the
    # full array: we are biased towards taking the "hull" spanned by the local
    # devices, and if they don't form a subcube that hull contains non-local
    # devices.
    assert is_local_device[subcube_indices].all()
    return Mesh(self.devices[subcube_indices], self.axis_names)

  def __getitem__(self, new_axes):
    axis_pos = {name: i for i, name in enumerate(self.axis_names)}
    new_devices = self.devices.transpose(tuple(axis_pos[axis] for axis in new_axes) +
                                         tuple(axis_pos[axis] for axis in self.axis_names
                                               if axis not in new_axes))
    new_devices = new_devices[(slice(None),) * len(new_axes) +
                              (0,) * (len(self.axis_names) - len(new_axes))]
    return Mesh(new_devices, new_axes)

  @property
  def device_ids(self):
    return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)

  def __repr__(self):
    return f"Mesh({self.devices!r}, {self.axis_names!r})"


def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
  if aval is core.abstract_unit:
    return aval
  assert isinstance(aval, ShapedArray)
  shape = list(aval.shape)
  named_shape = dict(aval.named_shape)
  for name, axis in in_axes.items():
    assert shape[axis] % axis_sizes[name] == 0
    assert name not in named_shape
    named_shape[name] = axis_sizes[name]
    shape[axis] //= axis_sizes[name]
  return aval.update(shape=tuple(shape), named_shape=named_shape)

def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
  if aval is core.abstract_unit:
    return aval
  assert isinstance(aval, ShapedArray)
  shape = list(aval.shape)
  named_shape = dict(aval.named_shape)
  for name, axis in out_axes.items():
    shape[axis] *= axis_sizes[name]
    named_shape.pop(name, None)  # The name might be missing --- it's a broadcast.
  return aval.update(shape=tuple(shape), named_shape=named_shape)
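
# Worked example (illustrative): tiling a ShapedArray of shape (8, 6) with
# in_axes = {'x': 0} over axis_sizes = {'x': 2, 'y': 3} gives a ShapedArray of
# shape (4, 6) with named_shape {'x': 2}; untile_aval_nd with the same mapping
# reverses this, scaling axis 0 back up by 2 and dropping 'x' from the named
# shape.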

def mesh_callable(fun: lu.WrappedFun,
                  transformed_name: str,
                  backend_name: Optional[str],
                  mesh: Mesh,
                  in_axes: Sequence[ArrayMapping],
                  out_axes_thunk: Callable[[], Sequence[ArrayMapping]],
                  donated_invars: Sequence[bool],
                  spmd_lowering: bool,
                  *local_in_untiled_avals,
                  tile_by_mesh_axes: bool):
  local_mesh = mesh.local_mesh
  global_axis_sizes = mesh.shape
  local_axis_sizes = local_mesh.shape

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority,
              f"Compiling {fun.__name__} ({id(fun)}) for {tuple(global_axis_sizes.items())} "
              f"mesh with args {local_in_untiled_avals}. Argument mapping: {in_axes}.")

  # 1. Trace to jaxpr and preprocess/verify it
  in_tiled_avals = [tile_aval_nd(local_axis_sizes, aval_in_axes, aval)
                    for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
  if spmd_lowering:
    # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
    for name, size in reversed(mesh.shape.items()):
      if tile_by_mesh_axes:
        fun = batching.vtile(fun,
                             tuple(a.get(name, None) for a in in_axes),
                             tuple(a.get(name, None) for a in out_axes_thunk()),
                             tile_size=size, axis_name=name)
    global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
                               for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
    in_jaxpr_avals = global_in_untiled_avals
  else:
    in_jaxpr_avals = in_tiled_avals
  with core.extend_axis_env_nd(mesh.shape.items()):
    jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
  out_axes = out_axes_thunk()
  assert len(out_axes) == len(out_jaxpr_avals)
  if spmd_lowering:
    global_out_untiled_avals = out_jaxpr_avals
    out_tiled_avals = [tile_aval_nd(global_axis_sizes, aval_out_axes, aval)
                       for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
  else:
    out_tiled_avals = out_jaxpr_avals
  local_out_untiled_avals = [untile_aval_nd(local_axis_sizes, aval_out_axes, aval)
                             for aval, aval_out_axes in safe_zip(out_tiled_avals, out_axes)]
  _sanitize_mesh_jaxpr(jaxpr)
  if local_mesh.shape != mesh.shape:
    check_multihost_collective_allowlist(jaxpr)
  jaxpr = xla.apply_outfeed_rewriter(jaxpr)

  # 3. Build up the HLO
  c = xb.make_computation_builder(f"xmap_{fun.__name__}")
  xla_consts = map(partial(xb.constant, c), consts)
  tuple_args = len(in_jaxpr_avals) > 100  # pass long arg lists as tuple for TPU
  in_partitions: Optional[List]
  if spmd_lowering:
    replicated_args = [False] * len(in_jaxpr_avals)
    global_sharding_spec = mesh_sharding_specs(global_axis_sizes, mesh.axis_names)
    in_partitions = [global_sharding_spec(aval, aval_in_axes).sharding_proto()
                     if aval is not core.abstract_unit else None
                     for aval, aval_in_axes in safe_zip(global_in_untiled_avals, in_axes)]
    out_partitions = [global_sharding_spec(aval, aval_out_axes).sharding_proto()
                      for aval, aval_out_axes in safe_zip(global_out_untiled_avals, out_axes)]
    partitions_proto = True
    axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())  # All named axes have been vmapped
  else:
    replicated_args = [not axis for axis in in_axes]
    in_partitions = None
    partitions_proto = False
    axis_env = xla.AxisEnv(nreps=mesh.size,
                           names=tuple(global_axis_sizes.keys()),
                           sizes=tuple(global_axis_sizes.values()))
  xla_args, donated_invars = xla._xla_callable_args(
      c, in_jaxpr_avals, tuple_args,
      replicated=replicated_args,
      partitions=in_partitions,
      partitions_proto=partitions_proto,
      donated_invars=donated_invars)
  with core.extend_axis_env_nd(mesh.shape.items()):
    out_nodes = xla.jaxpr_subcomp(
        c, jaxpr, backend_name, axis_env, xla_consts,
        extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args)
  backend = xb.get_backend(backend_name)
  if spmd_lowering:
    out_partitions_t = xb.tuple_sharding_proto(out_partitions)
    out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
  else:
    out_tuple = xops.Tuple(c, out_nodes)
  if backend.platform in ("gpu", "tpu"):
    xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
  # TODO: Warn about unused donations?
  built = c.Build(out_tuple)

  return compile_and_wrap_mesh_hlo(built, backend, mesh, local_in_untiled_avals,
                                   local_out_untiled_avals, in_axes, out_axes,
                                   spmd_lowering, tuple_args)
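
# Call-pattern sketch (hypothetical values; the real caller is the xmap
# machinery, which prepares `fun`, the axis mappings, and the avals):
#   execute = mesh_callable(fun, "f", None, mesh,
#                           [OrderedDict([('x', 0)])],          # in_axes
#                           lambda: [OrderedDict([('x', 0)])],  # out_axes_thunk
#                           [False],                            # donated_invars
#                           True,                               # spmd_lowering
#                           *local_avals, tile_by_mesh_axes=True)
#   outs = execute(*args)  # args are this process's (local) argument values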


def compile_and_wrap_mesh_hlo(computation: xc.XlaComputation, backend,
                              mesh: Mesh,
                              local_in_untiled_avals: Sequence[ShapedArray],
                              local_out_untiled_avals: Sequence[ShapedArray],
                              in_axes: Sequence[ArrayMapping],
                              out_axes: Sequence[ArrayMapping],
                              spmd_lowering: bool, tuple_args: bool):
  local_mesh = mesh.local_mesh
  local_axis_sizes = local_mesh.shape
  if spmd_lowering:
    num_replicas, num_partitions = 1, mesh.size
    num_local_replicas, num_local_partitions = 1, local_mesh.size
  else:
    num_replicas, num_partitions = mesh.size, 1
    num_local_replicas, num_local_partitions = local_mesh.size, 1
  device_assignment = mesh.device_ids.reshape((num_replicas, num_partitions))
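  # For orientation: SPMD lowering compiles a single partitioned program (e.g. a
  # 6-device mesh gives num_replicas=1, num_partitions=6), while the non-SPMD
  # path compiles one replica per device (num_replicas=6, num_partitions=1).
  # device_assignment is just mesh.device_ids reshaped to that
  # (replicas, partitions) grid.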
  compile_options = xb.get_compile_options(
      num_replicas=num_replicas,
      num_partitions=num_partitions,
      device_assignment=device_assignment,
      use_spmd_partitioning=spmd_lowering,
  )
  compile_options.parameter_is_tupled_arguments = tuple_args
  compiled = xla.backend_compile(backend, computation, compile_options)

  local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
  local_input_specs = [local_sharding_spec(aval, aval_in_axes)
                       if aval is not core.abstract_unit else None
                       for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
  input_indices = [spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in safe_zip(local_in_untiled_avals, local_input_specs)]
  handle_args = partial(shard_args, compiled.local_devices(), input_indices)

  local_output_specs = [local_sharding_spec(aval, aval_out_axes)
                        for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
  handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
                                         local_output_specs, local_out_untiled_avals)

  if hasattr(backend, "wrap_execute_replicated"):
    return backend.wrap_execute_replicated(compiled, compiled.local_devices(),
                                           input_indices, local_input_specs,
                                           handle_outs)
  return partial(execute_replicated, compiled, backend, handle_args, handle_outs)

_forbidden_primitives = {
  'xla_pmap': 'pmap',
  'sharded_call': 'sharded_jit',
}
def _sanitize_mesh_jaxpr(jaxpr):
  for eqn in jaxpr.eqns:
    if eqn.primitive.name in _forbidden_primitives:
      raise RuntimeError(f"Nesting {_forbidden_primitives[eqn.primitive.name]} "
                         f"inside xmaps not supported!")
    core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)


def mesh_sharding_specs(axis_sizes, axis_names):
  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
  # NOTE: This takes in the non-sharded avals!
  def mk_sharding_spec(aval, aval_axes):
    sharding = [_UNSHARDED_INSTANCE] * len(aval.shape)
    mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
    next_sharded_axis = 0
    aval_shape = list(aval.shape)
    # NOTE: sorted is stable, which is important when multiple resources
    # map to the same axis.
    for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
      assert aval_shape[axis] % axis_sizes[name] == 0, (axis_sizes[name], aval.shape[axis])
      aval_shape[axis] //= axis_sizes[name]
      if isinstance(sharding[axis], NoSharding):
        sharding[axis] = Chunked([])
      sharding[axis] = Chunked(sharding[axis].chunks + [axis_sizes[name]])
      assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
          "Value mapped to the same mesh axis twice"
      mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
      next_sharded_axis += 1
    return ShardingSpec(sharding, mesh_mapping)
  return mk_sharding_spec
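
# Worked example (illustrative): with axis_sizes = {'x': 2, 'y': 3} and
# axis_names = ('x', 'y'), applying the returned mk_sharding_spec to an aval of
# shape (6, 4) with aval_axes = {'x': 0, 'y': 0} yields a spec equivalent to
#   ShardingSpec(sharding=(Chunked([2, 3]), NoSharding()),
#                mesh_mapping=(ShardedAxis(0), ShardedAxis(1)))
# i.e. axis 0 is chunked first by 'x' and then by 'y', matching the
# major-to-minor ordering described in the ArrayMapping docstring.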

@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
  with core.extend_axis_env(*args, **kwargs):
    yield


class DynamicAxisEnvFrame(object):
  __slots__ = ["name", "pmap_trace", "hard_size"]
  def __init__(self, name, pmap_trace, hard_size):
    self.name = name
    self.pmap_trace = pmap_trace
    self.hard_size = hard_size

class DynamicAxisEnv(list):
  def __contains__(self, axis_name):
    return axis_name in (frame.name for frame in self)

  def __getitem__(self, axis_name):
    if axis_name not in self:
      raise NameError("unbound axis name: {}".format(axis_name))
    for frame in reversed(self):
      if frame.name == axis_name:
        return frame

    raise AssertionError

  @property
  def sizes(self):
    return tuple(frame.hard_size for frame in self)

  @property
  def nreps(self):
    return prod(frame.hard_size for frame in self)

class _ThreadLocalState(threading.local):
  def __init__(self):
    self.dynamic_axis_env = DynamicAxisEnv()

_thread_local_state = _ThreadLocalState()

def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool=False) -> List[xb.xla_client._xla.PyLocalBuffer]:
  """Call device_put on a sequence of devices and return a flat sequence of buffers."""
  if replicate:
    return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices))
  else:
    return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devices)))
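
# Usage sketch (hypothetical names): with replicate=True, `x` is a single value
# copied onto every device in `devices`; with replicate=False (the default), `x`
# must be a sequence of per-device values zipped against `devices`:
#   bufs = device_put(host_array, local_devices, replicate=True)  # one copy per device
#   bufs = device_put(shards, local_devices)                      # shards[i] -> local_devices[i]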