mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
3026 lines
121 KiB
Python
3026 lines
121 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Implementation of pmap and related functionality."""
|
|
|
|
# A ShardingSpec describes at a high level how a logical array is sharded across
|
|
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
|
|
# describe how to shard inputs to a parallel computation). spec_to_indices()
|
|
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
|
|
# how the sharded array is "laid out" across devices. Given a sequence of
|
|
# devices, we shard the data across the devices in row-major order, with
|
|
# replication treated as an extra inner dimension.
|
|
#
|
|
# For example, given the logical data array [1, 2, 3, 4], if we were to
|
|
# partition this array 4 ways with a replication factor of 2, for a total of 8
|
|
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
|
|
#
|
|
# This encoding is assumed by various parts of the system, e.g. generating
|
|
# replica groups for collective operations.
|
|
|
|
from __future__ import annotations
|
|
|
|
import enum
|
|
from contextlib import contextmanager, ContextDecorator
|
|
from collections import defaultdict, OrderedDict
|
|
import dataclasses
|
|
from functools import partial, lru_cache
|
|
import itertools as it
|
|
import operator as op
|
|
import threading
|
|
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
|
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
|
|
TYPE_CHECKING)
|
|
import sys
|
|
|
|
from absl import logging
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import core
|
|
from jax import linear_util as lu
|
|
from jax.core import ConcreteArray, ShapedArray
|
|
from jax.errors import JAXTypeError
|
|
from jax.interpreters import ad
|
|
from jax.interpreters import batching
|
|
from jax.interpreters import mlir
|
|
from jax.interpreters import partial_eval as pe
|
|
from jax.interpreters import xla
|
|
from jax.tree_util import tree_flatten, tree_map
|
|
|
|
from jax._src import abstract_arrays
|
|
from jax._src import api_util
|
|
from jax._src import device_array
|
|
from jax._src import source_info_util
|
|
from jax._src import util
|
|
from jax._src import dispatch
|
|
from jax._src import profiler
|
|
from jax._src import stages
|
|
from jax._src.abstract_arrays import array_types
|
|
from jax._src.config import config
|
|
from jax._src.lib import xla_bridge as xb
|
|
from jax._src.lib import xla_client as xc
|
|
from jax._src.lib import xla_extension_version
|
|
from jax._src.lib import pmap_lib
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import mhlo
|
|
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
|
|
new_name_stack, wrap_name, assert_unreachable,
|
|
tuple_insert, tuple_delete, distributed_debug_log)
|
|
|
|
if TYPE_CHECKING:
|
|
from jax.experimental.sharding import MeshPspecSharding, XLACompatibleSharding
|
|
|
|
# Built in Python lists don't support weak refs but subclasses of lists do.
|
|
class WeakRefList(list):
|
|
pass
|
|
|
|
if sys.version_info >= (3, 8):
|
|
from functools import cached_property as maybe_cached_property
|
|
else:
|
|
maybe_cached_property = property
|
|
|
|
if sys.version_info >= (3, 9):
|
|
OrderedDictType = OrderedDict
|
|
else:
|
|
OrderedDictType = Dict
|
|
|
|
xe = xc._xla
|
|
|
|
unsafe_map, map = map, safe_map # type: ignore
|
|
|
|
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
|
|
|
|
NoSharding = pmap_lib.NoSharding
|
|
Chunked = pmap_lib.Chunked
|
|
Unstacked = pmap_lib.Unstacked
|
|
|
|
ShardedAxis = pmap_lib.ShardedAxis
|
|
Replicated = pmap_lib.Replicated
|
|
|
|
_UNSHARDED_INSTANCE = NoSharding()
|
|
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
|
|
MeshDimAssignment = Union[ShardedAxis, Replicated]
|
|
ShardingSpec = pmap_lib.ShardingSpec
|
|
|
|
MeshAxisName = Any
|
|
OpShardingType = Any
|
|
|
|
|
|
def sharding_spec_mesh_shape(self):
|
|
sharded_axis_sizes = []
|
|
for sharding in self.sharding:
|
|
if isinstance(sharding, NoSharding):
|
|
continue
|
|
elif isinstance(sharding, Unstacked):
|
|
sharded_axis_sizes.append(sharding.size)
|
|
elif isinstance(sharding, Chunked):
|
|
sharded_axis_sizes.extend(sharding.chunks)
|
|
else:
|
|
assert_unreachable(sharding)
|
|
return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas
|
|
for a in self.mesh_mapping)
|
|
|
|
|
|
def _get_logical_mesh_ids(mesh_shape):
|
|
return np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
|
|
|
|
|
|
def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType] = {}):
|
|
"""Converts a ShardingSpec to an OpSharding proto.
|
|
|
|
See
|
|
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
|
|
for details on the OpSharding proto.
|
|
Unfortunately the semantics are not very well described in the proto spec, but the code here might help:
|
|
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
|
|
"""
|
|
mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
|
|
mesh = _get_logical_mesh_ids(self.mesh_shape)
|
|
|
|
sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
|
|
replicated_maxes = [] # lists mesh axis identifiers to replicate over
|
|
for maxis, assignment in enumerate(self.mesh_mapping):
|
|
if isinstance(assignment, Replicated):
|
|
replicated_maxes.append((maxis, assignment.replicas))
|
|
elif isinstance(assignment, ShardedAxis):
|
|
sharded_axes[assignment.axis] = maxis
|
|
else:
|
|
assert_unreachable(assignment)
|
|
|
|
proto = xc.OpSharding()
|
|
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
|
|
proto.type = xc.OpSharding.Type.REPLICATED
|
|
return proto
|
|
else:
|
|
proto.type = xc.OpSharding.Type.OTHER
|
|
|
|
mesh_permutation = []
|
|
new_mesh_shape = []
|
|
next_sharded_axis = 0
|
|
for axis, sharding in enumerate(self.sharding):
|
|
if isinstance(sharding, NoSharding):
|
|
new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over
|
|
elif isinstance(sharding, Chunked):
|
|
for nchunks in sharding.chunks:
|
|
maxis = sharded_axes[next_sharded_axis]
|
|
assert mesh_shape[maxis] == nchunks
|
|
mesh_permutation.append(maxis)
|
|
next_sharded_axis += 1
|
|
new_mesh_shape.append(int(np.prod(sharding.chunks)))
|
|
elif isinstance(sharding, Unstacked):
|
|
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
|
|
else:
|
|
assert_unreachable(sharding)
|
|
|
|
# Create a partial sharding proto if tensor is replicated or partitioned
|
|
# specially over some mesh axes.
|
|
if replicated_maxes:
|
|
last_tile_dims = []
|
|
axes_by_type: Dict[OpShardingType, List[MeshAxisName]] = {}
|
|
size_by_type: Dict[OpShardingType, int] = defaultdict(lambda: 1)
|
|
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys()))
|
|
for axis, size in replicated_maxes:
|
|
ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
|
|
axes_by_type.setdefault(ty, []).append(axis)
|
|
size_by_type[ty] *= size
|
|
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
|
|
last_tile_dims.append(ty)
|
|
new_mesh_shape.append(size_by_type[ty])
|
|
mesh_permutation.extend(axes)
|
|
proto.last_tile_dims = last_tile_dims
|
|
|
|
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
|
|
proto.tile_assignment_dimensions = list(proto_mesh.shape)
|
|
proto.tile_assignment_devices = list(proto_mesh.flat)
|
|
return proto
|
|
|
|
|
|
def _get_num_ways_dim_sharded(op_sharding: xc.OpSharding) -> Tuple[Sequence[int], int]:
|
|
partitions = op_sharding.tile_assignment_dimensions
|
|
if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
|
|
replicate_on_last_tile_dim = True
|
|
else:
|
|
replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
|
|
if op_sharding.last_tile_dims:
|
|
raise NotImplementedError("Unhandled OpSharding type. Please open a bug report!")
|
|
num_replicas = 1
|
|
if replicate_on_last_tile_dim:
|
|
num_replicas = partitions[-1]
|
|
partitions = partitions[:-1]
|
|
return partitions, num_replicas
|
|
|
|
|
|
def _op_sharding_to_numpy_indices(
|
|
op_sharding: xc.OpSharding, shape: Tuple[int, ...],
|
|
num_devices: int) -> np.ndarray:
|
|
indices = np.empty(num_devices, dtype=np.object_)
|
|
|
|
# num_devices is required as an argument when op_sharding is
|
|
# REPLICATED. `jax.device_count()` cannot be used because you can create
|
|
# an opsharding with less number of devices than `jax.device_count()`.
|
|
if is_op_sharding_replicated(op_sharding):
|
|
indices.fill((slice(None),) * len(shape))
|
|
return indices
|
|
|
|
assert num_devices == len(op_sharding.tile_assignment_devices)
|
|
|
|
partitions, num_replicas = _get_num_ways_dim_sharded(op_sharding)
|
|
assert len(partitions) == len(shape), (len(partitions), len(shape))
|
|
|
|
axis_indices: List[Sequence[Index]] = []
|
|
for dim, n_shards in zip(shape, partitions):
|
|
if n_shards == 1:
|
|
axis_indices.append([slice(None)])
|
|
elif n_shards > 1:
|
|
shard_size, ragged = divmod(dim, n_shards)
|
|
assert not ragged, (dim, n_shards, dim)
|
|
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
|
|
for i in range(n_shards)])
|
|
else:
|
|
raise AssertionError('Unrecognized number of shards. Please file a bug!')
|
|
|
|
device_it = iter(op_sharding.tile_assignment_devices)
|
|
for i, idxs in enumerate(it.product(*axis_indices)):
|
|
for _ in range(num_replicas):
|
|
indices[next(device_it)] = idxs
|
|
return indices
|
|
|
|
|
|
def op_sharding_to_indices(op_sharding: xc.OpSharding, shape: Tuple[int, ...],
|
|
num_devices: int) -> Tuple[Tuple[slice, ...], ...]:
|
|
indices = _op_sharding_to_numpy_indices(op_sharding, shape, num_devices)
|
|
return tuple(indices.flat)
|
|
|
|
|
|
def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
|
|
"""Returns NumPy-style indices corresponding to a sharding spec.
|
|
|
|
Args:
|
|
shape: The shape of the logical array being sharded.
|
|
|
|
Returns:
|
|
An ndarray with the same shape as the logical mesh (as derived form
|
|
`mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
|
|
the data array to be placed on a corresponding device. The indices can be
|
|
ints, slice objects with step=1, or tuples of those.
|
|
"""
|
|
assert len(shape) == len(self.sharding), (shape, self.sharding)
|
|
|
|
has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
|
|
# Take the op sharding indices generation route for pjit/xmap cases.
|
|
if not has_unstacked:
|
|
op_sharding_proto = sharding_spec_sharding_proto(self)
|
|
return _op_sharding_to_numpy_indices(
|
|
op_sharding_proto, shape, prod(self.mesh_shape)).reshape(self.mesh_shape)
|
|
|
|
axis_indices: List[Sequence[Index]] = []
|
|
shard_indices_shape = []
|
|
for dim, sharding in enumerate(self.sharding):
|
|
axis_size = shape[dim]
|
|
if isinstance(sharding, NoSharding):
|
|
axis_indices.append([slice(None)])
|
|
# NOTE: We don't append unsharded dimensions to shard_indices_shape here,
|
|
# because they do not appear in the mesh mapping.
|
|
elif isinstance(sharding, Unstacked):
|
|
assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
|
|
axis_indices.append(range(axis_size))
|
|
shard_indices_shape.append(axis_size)
|
|
elif isinstance(sharding, Chunked):
|
|
total_chunks = int(np.prod(sharding.chunks))
|
|
shard_size, ragged = divmod(axis_size, total_chunks)
|
|
assert not ragged, (axis_size, total_chunks, dim)
|
|
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
|
|
for i in range(total_chunks)])
|
|
shard_indices_shape.extend(sharding.chunks)
|
|
else:
|
|
assert_unreachable(sharding)
|
|
|
|
# shard_indices is an ndarray representing the sharded axes of the logical array,
|
|
# with each dimension having size equal to the number of shards across the corresponding
|
|
# logical array dimension, and each element containing the multi-dimensional index that
|
|
# is used to extract the corresponding shard of the logical array.
|
|
shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object_)
|
|
for i, idxs in enumerate(it.product(*axis_indices)):
|
|
shard_indices[i] = idxs
|
|
shard_indices = shard_indices.reshape(shard_indices_shape)
|
|
|
|
# Ensure that each sharded axis is used exactly once in the mesh mapping
|
|
num_sharded_dim = len(shard_indices_shape)
|
|
sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
|
|
assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
|
|
len(sharded_dim_perm) == num_sharded_dim)
|
|
# Replicate/reorder the indices according to the mesh mapping
|
|
replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
|
|
replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm)
|
|
perm = [next(replica_dim) if isinstance(a, Replicated) else
|
|
len(replica_sizes) + next(sharded_dim)
|
|
for a in self.mesh_mapping]
|
|
return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
|
|
.transpose(perm))
|
|
|
|
def sharding_spec_repr(self):
|
|
return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
|
|
|
|
|
|
ShardingSpec.mesh_shape = property(sharding_spec_mesh_shape)
|
|
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
|
|
ShardingSpec.indices = sharding_spec_indices
|
|
# mypy raises: error: Cannot assign to a method [assignment]
|
|
ShardingSpec.__repr__ = sharding_spec_repr # type: ignore
|
|
# Do not pollute the namespace
|
|
del sharding_spec_mesh_shape, sharding_spec_indices, sharding_spec_repr
|
|
|
|
def spec_to_indices(shape: Tuple[int, ...],
|
|
spec: ShardingSpec) -> Tuple[Index, ...]:
|
|
"""Returns numpy-style indices corresponding to a sharding spec.
|
|
|
|
Each index describes a shard of the array. The order of the indices is the
|
|
same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
|
|
row-major).
|
|
|
|
Args:
|
|
shape: The shape of the logical array being sharded.
|
|
spec: Describes how the array is sharded and how the shards are assigned to
|
|
the logical mesh.
|
|
|
|
Returns:
|
|
A tuple of length equal to the size of the mesh (inferred as the product of
|
|
sharded dimension sizes and all replication factors). Each element is an
|
|
int, a slice object with step=1, or a tuple thereof, to be treated as an
|
|
index into the full logical array.
|
|
"""
|
|
return tuple(spec.indices(shape).flat) # type: ignore
|
|
|
|
|
|
### util
|
|
|
|
def identity(x): return x
|
|
|
|
def _shard_arg(arg, devices, arg_indices):
|
|
"""Returns a list of size len(devices) containing per-device buffers.
|
|
|
|
For the C++ pmap path, we fallback to Python (this function) to shard
|
|
arguments that are not supported by the C++ `ShardArg`.
|
|
|
|
Arrgs:
|
|
arg: The Python argument.
|
|
devices: The list of devices to shard over.
|
|
arg_indices: A list of `len(devices)` indices to use to shard the argument.
|
|
"""
|
|
if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices:
|
|
# The shard_arg_handlers allow an extensible set of types to be sharded, but
|
|
# inline handling for ShardedDeviceArray as a special case for performance
|
|
# NOTE: we compare indices instead of sharding_spec because
|
|
# pmap_benchmark.pmap_shard_args_benchmark indicates this is faster.
|
|
return [
|
|
buf if buf.device() == d else buf.copy_to_device(d)
|
|
for d, buf in zip(devices, arg.device_buffers)
|
|
]
|
|
else:
|
|
arg = xla.canonicalize_dtype(arg)
|
|
return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
|
|
|
|
|
|
@profiler.annotate_function
|
|
def shard_args(devices: Sequence[xb.xla_client.Device],
|
|
indices: Sequence[Sequence[Index]],
|
|
args) -> Sequence[Sequence[xb.xla_client.Buffer]]:
|
|
"""Shard each argument data array along its leading axis.
|
|
|
|
Args:
|
|
devices: sequence of Devices mapping replica index to a physical device.
|
|
indices: sequence of the same length as `args` describing how each arg
|
|
should be sharded/replicated across `devices`. Each element in `indices`
|
|
is the same length as `devices`.
|
|
args: a sequence of JaxTypes representing arguments to be sharded according
|
|
to `indices` and placed on `devices`.
|
|
|
|
Returns:
|
|
A list of length matching args, containing lists of per-device buffers
|
|
for each argument.
|
|
"""
|
|
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
|
|
|
|
|
|
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
|
|
def _shard_array(x, devices, indices):
|
|
return device_put([x[i] for i in indices], devices)
|
|
for _t in array_types:
|
|
shard_arg_handlers[_t] = _shard_array
|
|
|
|
def _shard_device_array(x, devices, indices):
|
|
start_indices, limit_indices, removed_dims = unzip3(
|
|
_as_slice_indices(x, idx) for idx in indices)
|
|
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
|
|
return device_put(shards, devices)
|
|
for t in device_array.device_array_types:
|
|
shard_arg_handlers[t] = _shard_device_array
|
|
|
|
|
|
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
|
|
# from the input ShardingSpec, rather than the indices. However, this would
|
|
# require duplicating the ordering logic of spec_to_indices, which is more
|
|
# subtle and more likely to change than the index logic we have to support here.
|
|
def _as_slice_indices(arr: device_array.DeviceArrayProtocol, idx: Index) -> Tuple[
|
|
Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
|
|
"""Returns start_indices, limit_indices, removed_dims"""
|
|
start_indices = [0] * arr.ndim
|
|
limit_indices = list(arr.shape)
|
|
removed_dims = []
|
|
|
|
tuple_idx = idx if isinstance(idx, tuple) else (idx,)
|
|
for dim, sub_idx in enumerate(tuple_idx):
|
|
if isinstance(sub_idx, int):
|
|
start_indices[dim] = sub_idx
|
|
limit_indices[dim] = sub_idx + 1
|
|
removed_dims.append(dim)
|
|
elif sub_idx == slice(None):
|
|
continue
|
|
else:
|
|
assert isinstance(sub_idx, slice), sub_idx
|
|
assert isinstance(sub_idx.start, int), sub_idx
|
|
assert isinstance(sub_idx.stop, int), sub_idx
|
|
start_indices[dim] = sub_idx.start
|
|
limit_indices[dim] = sub_idx.stop
|
|
|
|
return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore
|
|
|
|
|
|
def shard_aval(size, axis: int, aval):
|
|
try:
|
|
return shard_aval_handlers[type(aval)](size, axis, aval)
|
|
except KeyError as err:
|
|
raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
|
|
shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
|
|
def _shard_abstract_array(size, axis: int, x):
|
|
try:
|
|
if x.shape[axis] != size:
|
|
raise ValueError(f"Axis size {size} does not match dimension {axis} of "
|
|
f"shape {x.shape}")
|
|
except IndexError:
|
|
raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None
|
|
return x.update(shape=tuple_delete(x.shape, axis))
|
|
shard_aval_handlers[ShapedArray] = _shard_abstract_array
|
|
|
|
|
|
class _AUTOAxisResource:
|
|
pass
|
|
AUTO = _AUTOAxisResource()
|
|
|
|
def _is_auto(x):
|
|
return isinstance(x, _AUTOAxisResource)
|
|
|
|
|
|
class _UnspecifiedValue:
|
|
pass
|
|
_UNSPECIFIED = _UnspecifiedValue()
|
|
|
|
def _is_unspecified(x):
|
|
return isinstance(x, _UnspecifiedValue)
|
|
|
|
"""
|
|
ArrayMapping specifies how an ndarray should map to mesh axes.
|
|
|
|
Note that the ordering is crucial for the cases when this mapping is non-injective
|
|
(i.e. when multiple mesh axes map to the same positional axis). Then, the
|
|
order of entries of the mapping determines a major-to-minor order on mesh axes,
|
|
according to which chunks of the value along the repeated dimension will be assigned.
|
|
|
|
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
|
|
The second dimension of the value would get chunked into 6 pieces, and assigned to the
|
|
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
|
|
that would mean that a flat list of chunks would get assigned to a flattened list of
|
|
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
|
|
mesh devices ndarray would have to be transposed before flattening and assignment.
|
|
"""
|
|
ArrayMapping = OrderedDictType[MeshAxisName, int]
|
|
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, _AUTOAxisResource,
|
|
_UnspecifiedValue]
|
|
|
|
|
|
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
|
|
if not array_mapping:
|
|
return PartitionSpec()
|
|
max_index = -1
|
|
reverse_map = defaultdict(list)
|
|
for axis, index in array_mapping.items():
|
|
reverse_map[index].append(axis)
|
|
if index > max_index:
|
|
max_index = index
|
|
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
|
|
for i in range(max_index + 1))
|
|
return PartitionSpec(*partitions)
|
|
|
|
|
|
def local_aval_to_result_handler(
|
|
aval: core.AbstractValue,
|
|
sharding_spec: Optional[ShardingSpec],
|
|
indices: Optional[Tuple[Index]],
|
|
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
|
|
"""Returns a function for handling the raw buffers of a single output aval.
|
|
|
|
Args:
|
|
aval: The local output AbstractValue.
|
|
sharding_spec: Indicates how the output is sharded across devices, or None
|
|
for non-array avals.
|
|
indices: The pre-computed result of spec_to_indices, or None for non-array
|
|
avals.
|
|
|
|
Returns:
|
|
A function for handling the Buffers that will eventually be produced
|
|
for this output. The function will return an object suitable for returning
|
|
to the user, e.g. a ShardedDeviceArray.
|
|
"""
|
|
try:
|
|
return local_result_handlers[type(aval)](aval, sharding_spec, indices)
|
|
except KeyError as err:
|
|
raise TypeError(
|
|
f"No pxla_result_handler for type: {type(aval)}") from err
|
|
|
|
PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
|
|
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
|
|
def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices):
|
|
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
|
|
indices)
|
|
local_result_handlers[ShapedArray] = sda_array_result_handler
|
|
local_result_handlers[ConcreteArray] = sda_array_result_handler
|
|
|
|
|
|
class OutputType(enum.Enum):
|
|
Array = 0
|
|
GlobalDeviceArray = 1
|
|
|
|
|
|
def global_aval_to_result_handler(
|
|
aval: core.AbstractValue, out_sharding,
|
|
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
|
|
"""Returns a function for handling the raw buffers of a single output aval.
|
|
|
|
Args:
|
|
aval: The global output AbstractValue.
|
|
out_axis_resources: A PartitionSpec specifying the sharding of outputs.
|
|
Used for creating GSDAs.
|
|
global_mesh: The global device mesh that generated this output. Used
|
|
for creating GSDAs.
|
|
|
|
Returns:
|
|
A function for handling the Buffers that will eventually be produced
|
|
for this output. The function will return an object suitable for returning
|
|
to the user, e.g. a ShardedDeviceArray.
|
|
"""
|
|
if config.jax_array:
|
|
output_type = OutputType.Array
|
|
elif config.jax_parallel_functions_output_gda:
|
|
output_type = OutputType.GlobalDeviceArray
|
|
try:
|
|
return global_result_handlers[(type(aval), output_type)](aval, out_sharding)
|
|
except KeyError as err:
|
|
raise TypeError(
|
|
f"No pxla_result_handler for type: {type(aval)}") from err
|
|
|
|
global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler] = {}
|
|
|
|
### lazy device-memory persistence and result handling
|
|
|
|
# TODO(jblespiau): Consider removing this option.
|
|
_USE_CPP_SDA = True
|
|
|
|
|
|
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
|
|
if sharded_dim is not None:
|
|
sharded_aval = aval.update(
|
|
shape=aval.shape[:sharded_dim] + aval.shape[sharded_dim+1:])
|
|
if sharded_dim_size is None:
|
|
sharded_dim_size = aval.shape[sharded_dim]
|
|
else:
|
|
assert sharded_dim_size is not None
|
|
sharded_aval = aval
|
|
|
|
return _pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
|
|
sharded_aval, sharded_dim)
|
|
|
|
|
|
def make_sharded_device_array(
|
|
aval: ShapedArray,
|
|
sharding_spec: Optional[ShardingSpec],
|
|
# Any is for JAX extensions implementing their own buffer.
|
|
device_buffers: List[Union[Any, xb.xla_client.Buffer]],
|
|
indices: Optional[Tuple[Index, ...]] = None,
|
|
):
|
|
"""Returns a ShardedDeviceArray implementation based on arguments.
|
|
|
|
Returns either a C++ SDA or a Python DeviceArray when the buffers are not
|
|
JAX buffers.
|
|
|
|
Args:
|
|
aval: The `ShapedArray` for this array.
|
|
sharding_spec: If `None`, assumes a pmap-style ShardedDeviceArrays over the
|
|
first dimension.
|
|
device_buffers: If a list of Jax `Buffer` objects, a C++ SDA will be
|
|
returned (if the version is high enough). Otherwise, a Python object will
|
|
be returned, for JAX extensions not implementing the C++ API.
|
|
indices: For caching purposes, will be computed if `None`.
|
|
"""
|
|
if sharding_spec is None:
|
|
sharding_spec = _create_pmap_sharding_spec(aval)
|
|
|
|
if indices is None:
|
|
indices = spec_to_indices(aval.shape, sharding_spec)
|
|
|
|
if (_USE_CPP_SDA and
|
|
(not device_buffers or
|
|
isinstance(device_buffers[0], xb.xla_client.Buffer))):
|
|
return pmap_lib.ShardedDeviceArray.make(
|
|
aval, sharding_spec, device_buffers,
|
|
indices, aval.weak_type)
|
|
|
|
return _ShardedDeviceArray(aval, sharding_spec, device_buffers, indices)
|
|
|
|
|
|
if _USE_CPP_SDA:
|
|
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
|
|
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
|
|
# benefit from its methods, and to have isinstance(x, DeviceArray) return true
|
|
ShardedDeviceArrayBase.__bases__ = ((device_array.DeviceArray,) + # type: ignore
|
|
ShardedDeviceArrayBase.__bases__)
|
|
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore
|
|
else:
|
|
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore
|
|
|
|
|
|
class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
|
|
"""A ShardedDeviceArray is an ndarray sharded across devices.
|
|
|
|
The purpose of a ShardedDeviceArray is to reduce the number of transfers when
|
|
executing replicated computations, by allowing results to persist on the
|
|
devices that produced them. That way dispatching a similarly replicated
|
|
computation that consumes the same sharded memory layout does not incur any
|
|
transfers.
|
|
|
|
A ShardedDeviceArray represents one logical ndarray value, and simulates the
|
|
behavior of an ndarray so that it can be treated by user code as an ndarray;
|
|
that is, it is only an optimization to reduce transfers.
|
|
|
|
Attributes:
|
|
aval: A ShapedArray indicating the shape and dtype of this array.
|
|
sharding_spec: describes how this array is sharded across `device_buffers`.
|
|
device_buffers: the buffers containing the data for this array. Each buffer
|
|
is the same shape and on a different device. Buffers are in row-major
|
|
order, with replication treated as an extra innermost dimension.
|
|
indices: the result of spec_to_indices(sharding_spec). Can optionally be
|
|
precomputed for efficiency. A list the same length as
|
|
`device_buffers`. Each index indicates what portion of the full array is
|
|
stored in the corresponding device buffer, i.e. `array[indices[i]] ==
|
|
device_buffers[i].to_py()`.
|
|
"""
|
|
__slots__ = [
|
|
"aval", "device_buffers", "sharding_spec", "indices",
|
|
"_one_replica_buffer_indices", "_npy_value"
|
|
]
|
|
|
|
def __init__(self,
|
|
aval: ShapedArray,
|
|
sharding_spec: ShardingSpec,
|
|
device_buffers: List[xb.xla_client.Buffer],
|
|
indices: Optional[Tuple[Index, ...]] = None):
|
|
super().__init__()
|
|
|
|
# TODO(skye): assert invariants. Keep performance in mind though.
|
|
if indices is None:
|
|
indices = spec_to_indices(aval.shape, sharding_spec)
|
|
|
|
self.aval = aval
|
|
self.device_buffers = device_buffers
|
|
self.sharding_spec = sharding_spec
|
|
self.indices = indices
|
|
self._npy_value = None
|
|
self._one_replica_buffer_indices = None
|
|
if config.jax_enable_checks:
|
|
assert type(aval) is ShapedArray
|
|
|
|
@property
|
|
def shape(self):
|
|
return self.aval.shape
|
|
|
|
@property
|
|
def dtype(self):
|
|
return self.aval.dtype
|
|
|
|
@property
|
|
def size(self):
|
|
return prod(self.aval.shape)
|
|
|
|
@property
|
|
def ndim(self):
|
|
return len(self.aval.shape)
|
|
|
|
def delete(self):
|
|
if self.device_buffers is None:
|
|
return
|
|
for buf in self.device_buffers:
|
|
buf.delete()
|
|
self.device_buffers = None
|
|
self._npy_value = None
|
|
|
|
|
|
def _one_replica_buffer_indices(indices: Tuple[Index, ...]):
|
|
"""Returns a set of buffer-indices containing one complete copy of the array."""
|
|
one_replica_indices = []
|
|
seen_index_hashes = set()
|
|
for i, index in enumerate(indices):
|
|
hashed_index = _hashable_index(index)
|
|
if hashed_index not in seen_index_hashes:
|
|
one_replica_indices.append(i)
|
|
seen_index_hashes.add(hashed_index)
|
|
return one_replica_indices
|
|
|
|
|
|
def _sda_one_replica_buffer_indices(self):
|
|
"""Indices of buffers containing one complete copy of the array data."""
|
|
if self._one_replica_buffer_indices is None:
|
|
self._one_replica_buffer_indices = _one_replica_buffer_indices(self.indices)
|
|
return self._one_replica_buffer_indices
|
|
|
|
|
|
def _sda_copy_to_host_async(self):
|
|
for buffer_index in self.one_replica_buffer_indices:
|
|
self.device_buffers[buffer_index].copy_to_host_async()
|
|
|
|
|
|
def _sda_check_if_deleted(self):
|
|
if self.device_buffers is None:
|
|
raise ValueError("ShardedDeviceArray has been deleted.")
|
|
|
|
|
|
def _sda_block_until_ready(self):
|
|
self._check_if_deleted()
|
|
for buf in self.device_buffers:
|
|
buf.block_until_ready()
|
|
return self
|
|
|
|
|
|
def _sda_value(self):
|
|
if self._npy_value is None:
|
|
self.copy_to_host_async()
|
|
npy_value = np.empty(self.aval.shape, self.aval.dtype)
|
|
for i in self.one_replica_buffer_indices:
|
|
npy_value[self.indices[i]] = self.device_buffers[i].to_py()
|
|
self._npy_value = npy_value
|
|
return self._npy_value
|
|
|
|
|
|
def _sda__getitem__(self, idx):
|
|
self._check_if_deleted()
|
|
if not isinstance(idx, tuple):
|
|
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
|
|
else:
|
|
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
|
|
if self._npy_value is None:
|
|
try:
|
|
buf_idx = self.indices.index(cidx)
|
|
except ValueError:
|
|
buf_idx = None
|
|
if buf_idx is not None:
|
|
buf = self.device_buffers[buf_idx]
|
|
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
|
|
return device_array.make_device_array(aval, None, buf)
|
|
return super(self.__class__, self).__getitem__(idx)
|
|
|
|
|
|
def _sda__iter__(self):
|
|
if self.ndim == 0:
|
|
raise TypeError("iteration over a 0-d array") # same as numpy error
|
|
else:
|
|
return (self[i] for i in range(self.shape[0]))
|
|
|
|
def _sda__reversed__(self):
|
|
if self.ndim == 0:
|
|
raise TypeError("iteration over a 0-d array") # same as numpy error
|
|
else:
|
|
return (self[i] for i in range(self.shape[0] - 1, -1, -1))
|
|
|
|
|
|
for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
|
|
setattr(sda, "one_replica_buffer_indices",
|
|
property(_sda_one_replica_buffer_indices))
|
|
setattr(sda, "copy_to_host_async", _sda_copy_to_host_async)
|
|
setattr(sda, "_check_if_deleted", _sda_check_if_deleted)
|
|
setattr(sda, "block_until_ready", _sda_block_until_ready)
|
|
setattr(sda, "_value", property(_sda_value))
|
|
setattr(sda, "__getitem__", _sda__getitem__)
|
|
setattr(sda, "__iter__", _sda__iter__)
|
|
setattr(sda, "__reversed__", _sda__reversed__)
|
|
|
|
del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
|
|
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)
|
|
|
|
|
|
ShardedDeviceArray: Type[object]
|
|
if _USE_CPP_SDA:
|
|
ShardedDeviceArray = pmap_lib.ShardedDeviceArrayBase
|
|
else:
|
|
ShardedDeviceArray = _ShardedDeviceArray
|
|
|
|
|
|
def _hashable_index(idx):
|
|
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
|
|
|
|
# The fast path is handled directly in shard_args().
|
|
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
|
|
def _shard_sharded_device_array_slow_path(x, devices, indices):
|
|
candidates = defaultdict(list)
|
|
for buf, idx in safe_zip(x.device_buffers, x.indices):
|
|
candidates[_hashable_index(idx)].append(buf)
|
|
|
|
bufs = []
|
|
for idx, device in safe_zip(indices, devices):
|
|
# Look up all buffers that contain the correct slice of the logical array.
|
|
candidates_list = candidates[_hashable_index(idx)]
|
|
if not candidates_list:
|
|
# This array isn't sharded correctly. Reshard it via host roundtrip.
|
|
# TODO(skye): more efficient reshard?
|
|
return shard_arg_handlers[type(x._value)](x._value, devices, indices)
|
|
# Try to find a candidate buffer already on the correct device,
|
|
# otherwise copy one of them.
|
|
for buf in candidates_list:
|
|
if buf.device() == device:
|
|
bufs.append(buf)
|
|
break
|
|
else:
|
|
bufs.append(buf.copy_to_device(device))
|
|
return bufs
|
|
|
|
|
|
def _sharded_device_array_mlir_constant_handler(val, canonicalize_types=True):
|
|
return mlir.ir_constants(np.asarray(val),
|
|
canonicalize_types=canonicalize_types)
|
|
|
|
def _register_handlers_for_sharded_device_array(sda):
|
|
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
|
|
mlir.register_constant_handler(sda,
|
|
_sharded_device_array_mlir_constant_handler)
|
|
|
|
core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
|
|
dispatch.device_put_handlers[sda] = dispatch._device_put_array
|
|
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
|
|
xla.canonicalize_dtype_handlers[sda] = identity
|
|
api_util._shaped_abstractify_handlers[sda] = op.attrgetter("aval")
|
|
|
|
_register_handlers_for_sharded_device_array(_ShardedDeviceArray)
|
|
_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)
|
|
|
|
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
|
|
|
|
def xla_pmap_impl(fun: lu.WrappedFun, *args,
|
|
backend: Optional[str],
|
|
axis_name: core.AxisName,
|
|
axis_size: int,
|
|
global_axis_size: Optional[int],
|
|
devices: Optional[Sequence[Any]],
|
|
name: str,
|
|
in_axes: Sequence[Optional[int]],
|
|
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
|
|
donated_invars: Sequence[bool],
|
|
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
|
|
abstract_args = unsafe_map(xla.abstractify, args)
|
|
compiled_fun, fingerprint = parallel_callable(
|
|
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
|
|
in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
|
|
*abstract_args)
|
|
|
|
# Don't re-abstractify args unless logging is enabled for performance.
|
|
if config.jax_distributed_debug:
|
|
distributed_debug_log(("Running pmapped function", name),
|
|
("python function", fun.f),
|
|
("devices", devices),
|
|
("abstract args", map(xla.abstractify, args)),
|
|
("fingerprint", fingerprint))
|
|
return compiled_fun(*args)
|
|
|
|
|
|
@lu.cache
|
|
def parallel_callable(fun: lu.WrappedFun,
|
|
backend_name: Optional[str],
|
|
axis_name: core.AxisName,
|
|
axis_size: int,
|
|
global_axis_size: Optional[int],
|
|
devices: Optional[Sequence[Any]],
|
|
name: str,
|
|
in_axes: Sequence[Optional[int]],
|
|
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
|
|
donated_invars: Sequence[bool],
|
|
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
|
|
*avals):
|
|
pmap_computation = lower_parallel_callable(
|
|
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
|
|
in_axes, out_axes_thunk, donated_invars, global_arg_shapes, avals)
|
|
pmap_executable = pmap_computation.compile()
|
|
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ParallelCallableInfo:
|
|
name: str
|
|
backend: xla.Backend
|
|
axis_name: core.AxisName
|
|
axis_size: int
|
|
global_axis_size: Optional[int]
|
|
devices: Optional[Sequence[xla.Device]]
|
|
in_axes: Iterable[Optional[int]]
|
|
out_axes_thunk: Callable[[], Sequence[Optional[int]]]
|
|
avals: Sequence[core.AbstractValue]
|
|
|
|
@maybe_cached_property
|
|
def local_devices(self):
|
|
if self.devices:
|
|
out = [d for d in self.devices
|
|
if d.process_index == xb.process_index(self.backend)]
|
|
assert len(out) > 0
|
|
else:
|
|
out = None # type: ignore
|
|
return out
|
|
|
|
@maybe_cached_property
|
|
def out_axes(self):
|
|
return self.out_axes_thunk()
|
|
|
|
|
|
class ShardInfo(NamedTuple):
|
|
sharded_avals: Sequence[core.AbstractValue]
|
|
out_sharded_avals: Sequence[core.AbstractValue]
|
|
global_sharded_avals: Sequence[core.AbstractValue]
|
|
num_local_shards: int
|
|
num_global_shards: int
|
|
|
|
|
|
class ReplicaInfo(NamedTuple):
|
|
jaxpr_replicas: int
|
|
num_local_replicas: int
|
|
num_global_replicas: int
|
|
|
|
|
|
def find_replicas(jaxpr, axis_size, global_axis_size):
|
|
# TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
|
|
jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
|
|
num_local_replicas = axis_size * jaxpr_replicas
|
|
num_global_replicas = global_axis_size * jaxpr_replicas
|
|
return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)
|
|
|
|
|
|
def should_tuple_args(shards: ShardInfo):
|
|
# tuplify long arg lists for TPU
|
|
return len(shards.global_sharded_avals) > 100
|
|
|
|
|
|
def stage_parallel_callable(
|
|
pci: ParallelCallableInfo,
|
|
fun: lu.WrappedFun,
|
|
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
|
|
sharded_avals = tuple(
|
|
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
|
|
for axis, aval in safe_zip(pci.in_axes, pci.avals))
|
|
if any(s is not None for s in global_arg_shapes):
|
|
# TODO(skye): we could take this branch unconditionally if we handled
|
|
# grad of global_arg_shapes correctly.
|
|
global_sharded_avals = [
|
|
aval.update(shape=shape) if shape is not None else aval
|
|
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
|
|
else:
|
|
global_sharded_avals = sharded_avals # type: ignore
|
|
|
|
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
|
|
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
|
"for pmap in {elapsed_time} sec"):
|
|
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
|
|
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
|
|
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
|
|
|
assert len(out_sharded_avals) == len(pci.out_axes), (
|
|
len(out_sharded_avals), len(pci.out_axes))
|
|
|
|
# TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
|
|
# for now raise an error
|
|
if pci.devices is not None:
|
|
is_multi_host_pmap = len(pci.local_devices) != len(pci.devices)
|
|
else:
|
|
is_multi_host_pmap = xb.process_count(pci.backend) > 1
|
|
if is_multi_host_pmap:
|
|
check_multihost_collective_allowlist(jaxpr)
|
|
|
|
replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
|
|
parts = find_partitions(jaxpr)
|
|
|
|
num_local_shards = replicas.num_local_replicas * parts.local_num_partitions
|
|
num_global_shards = replicas.num_global_replicas * parts.num_partitions
|
|
|
|
shards = ShardInfo(
|
|
sharded_avals, out_sharded_avals, global_sharded_avals,
|
|
num_local_shards, num_global_shards)
|
|
|
|
return jaxpr, consts, replicas, parts, shards
|
|
|
|
|
|
def _shardings_to_mlir_shardings(
|
|
shardings: Optional[Sequence[PartitionsOrReplicated]]
|
|
) -> Optional[Sequence[Optional[xc.OpSharding]]]:
|
|
if shardings is None:
|
|
return None
|
|
return [xla.sharding_to_proto(s) for s in shardings]
|
|
|
|
@profiler.annotate_function
|
|
def lower_parallel_callable(
|
|
fun: lu.WrappedFun,
|
|
backend_name: Optional[str],
|
|
axis_name: core.AxisName,
|
|
axis_size: int,
|
|
global_axis_size: Optional[int],
|
|
devices: Optional[Sequence[xla.Device]],
|
|
name: str,
|
|
in_axes: Iterable[Optional[int]],
|
|
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
|
|
donated_invars: Sequence[bool],
|
|
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
|
|
avals: Sequence[core.AbstractValue]):
|
|
if devices is not None and len(devices) == 0:
|
|
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
|
|
|
|
# Determine global_axis_size for use in AxisEnv.
|
|
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
|
|
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
|
|
# raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
|
|
if (xb.process_count() == 1 and global_axis_size is not None and
|
|
global_axis_size != axis_size):
|
|
raise ValueError(
|
|
f"Specified axis_size {global_axis_size} doesn't match received "
|
|
f"axis_size {axis_size}.")
|
|
|
|
if devices is not None and backend_name is None:
|
|
backend = xb.get_device_backend(devices[0])
|
|
else:
|
|
backend = xb.get_backend(backend_name)
|
|
|
|
must_run_on_all_devices = False
|
|
no_nested_sharding = False
|
|
if global_axis_size is None:
|
|
if xb.process_count(backend) == 1:
|
|
global_axis_size = axis_size
|
|
elif devices:
|
|
# This allows each host in a multi-host pmap to run on a different number
|
|
# of devices, but precludes nested sharding (i.e. inner pmaps or
|
|
# sharded_jits).
|
|
global_axis_size = len(devices)
|
|
no_nested_sharding = True
|
|
else:
|
|
# This assumes all hosts run on the same number of devices. We make sure
|
|
# this assumption is true by requiring that the pmap is run on all devices
|
|
# (and making the further assumption that each host has the same number of
|
|
# devices). Nested sharding is ok in this case.
|
|
global_axis_size = axis_size * xb.process_count(backend)
|
|
assert all(
|
|
len(xb.local_devices(process_index, backend)) == xb.local_device_count(backend)
|
|
for process_index in range(xb.process_count(backend)))
|
|
must_run_on_all_devices = True
|
|
|
|
pci = ParallelCallableInfo(
|
|
name, backend, axis_name, axis_size, global_axis_size, devices,
|
|
in_axes, out_axes_thunk, avals)
|
|
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
|
|
pci, fun, global_arg_shapes)
|
|
|
|
if logging.vlog_is_on(2):
|
|
logging.vlog(2, "sharded_avals: %s", shards.sharded_avals)
|
|
logging.vlog(2, "global_sharded_avals: %s", shards.global_sharded_avals)
|
|
logging.vlog(2, "num_replicas: %d num_local_replicas: %d",
|
|
replicas.num_global_replicas, replicas.num_local_replicas)
|
|
logging.vlog(2, "num_partitions: %d local_num_partitions: %d",
|
|
parts.num_partitions, parts.local_num_partitions)
|
|
logging.vlog(2, "arg_parts: %s", parts.arg_parts)
|
|
logging.vlog(2, "local_arg_parts: %s", parts.local_arg_parts)
|
|
logging.vlog(2, "out_parts: %s", parts.out_parts)
|
|
logging.vlog(2, "local_out_parts: %s", parts.local_out_parts)
|
|
logging.vlog(2, "devices: %s", devices)
|
|
logging.vlog(2, "local_devices: %s", pci.local_devices)
|
|
|
|
if (xb.process_count(backend) > 1 and must_run_on_all_devices and
|
|
shards.num_local_shards != xb.local_device_count(backend)):
|
|
if shards.num_local_shards == axis_size:
|
|
raise ValueError(
|
|
f"On multi-host platforms, the input to pmapped functions must have "
|
|
f"leading axis size equal to the number of local devices if no "
|
|
f"`devices` argument is specified. Got axis_size={axis_size}, "
|
|
f"num_local_devices={xb.local_device_count(backend)}")
|
|
else:
|
|
raise ValueError(
|
|
f"On multi-host platforms, pmapped functions must run across all "
|
|
f"devices, i.e. num_replicas * num_partitions should equal the "
|
|
f"number of local devices. Got "
|
|
f"num_replicas={replicas.num_local_replicas}, "
|
|
f"num_partitions={parts.num_partitions}, and "
|
|
f"num_local_devices={xb.local_device_count(backend)}")
|
|
|
|
if no_nested_sharding and (
|
|
replicas.jaxpr_replicas > 1 or parts.num_partitions > 1):
|
|
raise ValueError(
|
|
f"On multi-host platforms, pmapped functions that both have `devices` "
|
|
f"specified and contain an inner_pmap or sharded_jit must specify an "
|
|
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
|
|
f"{replicas.jaxpr_replicas} and nested_partitions={parts.num_partitions}")
|
|
|
|
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
|
logging.log(log_priority,
|
|
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d"
|
|
" num_partitions=%d)", fun.__name__, id(fun),
|
|
shards.num_global_shards, avals, replicas.num_global_replicas,
|
|
parts.num_partitions)
|
|
|
|
axis_env = xla.AxisEnv(
|
|
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
|
|
name_stack = new_name_stack(wrap_name(name, 'pmap'))
|
|
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
|
replicated_args = [axis is None for axis in in_axes]
|
|
module: Union[str, xc.XlaComputation]
|
|
tuple_args = should_tuple_args(shards)
|
|
module_name = f"pmap_{fun.__name__}"
|
|
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
|
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
|
|
raise ValueError("Ordered effects not supported in `pmap`.")
|
|
unordered_effects = [eff for eff in closed_jaxpr.effects
|
|
if eff not in core.ordered_effects]
|
|
lowering_result = mlir.lower_jaxpr_to_module(
|
|
module_name, closed_jaxpr, unordered_effects, [],
|
|
backend.platform, mlir.ReplicaAxisContext(axis_env),
|
|
name_stack, donated_invars, replicated_args=replicated_args,
|
|
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
|
|
result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
|
|
module, keepalive, host_callbacks = (
|
|
lowering_result.module, lowering_result.keepalive,
|
|
lowering_result.host_callbacks)
|
|
return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
|
|
shards=shards, tuple_args=tuple_args,
|
|
unordered_effects=unordered_effects,
|
|
keepalive=keepalive, host_callbacks=host_callbacks)
|
|
|
|
|
|
class PmapComputation(stages.XlaLowering):
|
|
_hlo: Union[ir.Module, xc.XlaComputation]
|
|
_executable: Optional[PmapExecutable]
|
|
|
|
def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
|
|
self._executable = None
|
|
self._hlo = hlo
|
|
self.compile_args = compile_args
|
|
|
|
# -- stages.XlaLowering overrides
|
|
|
|
def hlo(self) -> xc.XlaComputation:
|
|
# this is a method for api consistency with dispatch.XlaComputation
|
|
if isinstance(self._hlo, xc.XlaComputation):
|
|
return self._hlo
|
|
else:
|
|
return xe.mlir.mlir_module_to_xla_computation(
|
|
mlir.module_to_string(self._hlo),
|
|
use_tuple_args=self.compile_args["tuple_args"])
|
|
|
|
def mhlo(self) -> ir.Module:
|
|
if isinstance(self._hlo, xc.XlaComputation):
|
|
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
|
with mlir.make_ir_context():
|
|
return ir.Module.parse(module_str)
|
|
return self._hlo
|
|
|
|
@profiler.annotate_function
|
|
def compile(self) -> PmapExecutable:
|
|
if self._executable is None:
|
|
self._executable = PmapExecutable.from_hlo(self._hlo, **self.compile_args)
|
|
return self._executable
|
|
|
|
|
|
class PmapExecutable(stages.XlaExecutable):
|
|
__slots__ = ['xla_executable', 'unsafe_call', 'fingerprint', 'in_avals']
|
|
|
|
def __init__(self, xla_executable, unsafe_call, fingerprint, in_avals):
|
|
self.xla_executable = xla_executable
|
|
self.unsafe_call = unsafe_call
|
|
self.fingerprint = fingerprint
|
|
self.in_avals = in_avals
|
|
|
|
@staticmethod
|
|
def from_hlo(xla_computation,
|
|
pci: ParallelCallableInfo,
|
|
replicas: ReplicaInfo,
|
|
parts: PartitionInfo,
|
|
shards: ShardInfo,
|
|
tuple_args: bool,
|
|
unordered_effects: List[core.Effect],
|
|
host_callbacks: List[Any],
|
|
keepalive: Any):
|
|
devices = pci.devices
|
|
if devices is None:
|
|
if shards.num_global_shards > xb.device_count(pci.backend):
|
|
msg = ("compiling computation that requires {} logical devices, but only {} XLA "
|
|
"devices are available (num_replicas={}, num_partitions={})")
|
|
raise ValueError(msg.format(shards.num_global_shards,
|
|
xb.device_count(pci.backend),
|
|
replicas.num_global_replicas,
|
|
parts.num_partitions))
|
|
# On a single host, we use the platform's default device assignment to
|
|
# potentially take advantage of device locality. On multiple hosts, the
|
|
# default device assignment may interleave different hosts' replicas,
|
|
# violating pmap's semantics where data is sharded across replicas in
|
|
# row-major order. Instead, manually create a device assignment that ensures
|
|
# each host is responsible for a continguous set of replicas.
|
|
if shards.num_global_shards > shards.num_local_shards:
|
|
# TODO(skye): use a locality-aware assignment that satisfies the above
|
|
# constraint.
|
|
devices = [d for process_index in range(xb.process_count(pci.backend))
|
|
for d in xb.local_devices(process_index, pci.backend)]
|
|
else:
|
|
devices = xb.get_backend(pci.backend).get_default_device_assignment(
|
|
replicas.num_global_replicas, parts.num_partitions)
|
|
else:
|
|
if shards.num_local_shards != len(pci.local_devices):
|
|
local_devices_str = ", ".join(map(str, pci.local_devices))
|
|
if shards.num_local_shards == pci.axis_size:
|
|
raise ValueError(
|
|
f"Leading axis size of input to pmapped function must equal the "
|
|
f"number of local devices passed to pmap. Got axis_size="
|
|
f"{pci.axis_size}, num_local_devices={len(pci.local_devices)}.\n"
|
|
f"(Local devices available to pmap: {local_devices_str})")
|
|
else:
|
|
raise ValueError(
|
|
f"pmapped function requires {shards.num_local_shards} local "
|
|
f"devices to run due to nested pmapped or other parallel "
|
|
f"functions, but only {len(pci.local_devices)} are available.\n"
|
|
f"(outer axis size: {pci.axis_size}, local devices available to "
|
|
f"pmap: {local_devices_str})")
|
|
if shards.num_global_shards != len(devices):
|
|
raise ValueError("compiling computation that creates %s shards, "
|
|
"but %s devices were specified" %
|
|
(shards.num_global_shards, len(devices)))
|
|
|
|
# 'devices' may be 1D or 2D at this point (e.g.
|
|
# get_default_device_assignment() returns 2D assignment, caller may have
|
|
# provided 1D list of devices).
|
|
# Convert to 2D in case it's 1D and we have > 1 partitions.
|
|
device_assignment = np.array(devices).reshape(
|
|
(replicas.num_global_replicas, parts.num_partitions))
|
|
# TODO(b/162356737): Enabling SPMD partitioning causes issues with some
|
|
# non-partitioned workloads, so disable unless needed.
|
|
use_spmd_partitioning = parts.num_partitions > 1
|
|
compile_options = xb.get_compile_options(
|
|
num_replicas=replicas.num_global_replicas,
|
|
num_partitions=parts.num_partitions,
|
|
device_assignment=device_assignment,
|
|
use_spmd_partitioning=use_spmd_partitioning,
|
|
)
|
|
compile_options.parameter_is_tupled_arguments = tuple_args
|
|
|
|
process_index = xb.process_index(pci.backend)
|
|
local_device_assignment = np.array([
|
|
d for d in device_assignment.flat if d.process_index == process_index
|
|
])
|
|
|
|
local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
|
|
# TODO(yashkatariya): Fix the input handling of `Array`s that span over
|
|
# multiple processes. Add multi-process tests for pmap.
|
|
input_sharding_specs = [
|
|
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
|
|
parts.local_num_partitions, arg_parts, aval, in_axis)
|
|
for aval, arg_parts, in_axis in safe_zip(
|
|
shards.sharded_avals, local_arg_parts_, pci.in_axes)]
|
|
input_indices = [spec_to_indices(aval.shape, spec)
|
|
if spec is not None else None
|
|
for aval, spec in safe_zip(pci.avals, input_sharding_specs)]
|
|
in_shardings = _get_pmap_sharding(local_device_assignment, input_sharding_specs)
|
|
nouts = len(shards.out_sharded_avals)
|
|
|
|
out_parts, local_out_parts = parts.out_parts, parts.local_out_parts
|
|
if parts.out_parts is None:
|
|
out_parts = (None,) * nouts
|
|
if parts.local_out_parts is None:
|
|
local_out_parts = (None,) * nouts
|
|
|
|
if config.jax_array:
|
|
global_unmapped_avals = [
|
|
core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
|
|
if out_axis is not None else aval
|
|
for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
|
|
global_out_specs = [
|
|
_pmap_sharding_spec(replicas.num_global_replicas, pci.axis_size,
|
|
parts.num_partitions, op, aval, out_axis)
|
|
for op, aval, out_axis in safe_zip(
|
|
out_parts, shards.out_sharded_avals, pci.out_axes)]
|
|
pmap_shardings = _get_pmap_sharding(device_assignment, global_out_specs)
|
|
handle_outs = global_avals_to_results_handler(
|
|
global_unmapped_avals, pmap_shardings)
|
|
else:
|
|
local_out_avals = [
|
|
get_local_aval(aval, parts, lparts)
|
|
for aval, parts, lparts
|
|
in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)]
|
|
local_unmapped_avals = [
|
|
core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
|
|
if out_axis is not None else aval
|
|
for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
|
|
out_specs = [
|
|
_pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
|
|
parts.local_num_partitions, out_parts, aval, out_axis)
|
|
for out_parts, aval, out_axis in safe_zip(
|
|
local_out_parts, local_out_avals, pci.out_axes)]
|
|
pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
|
|
handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings)
|
|
|
|
if hasattr(pci.backend, "compile_replicated"):
|
|
execute_fun = pci.backend.compile_replicated(
|
|
xla_computation, compile_options, pci.avals, input_indices,
|
|
in_shardings, handle_outs)
|
|
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
|
return PmapExecutable(None, execute_fun, None, pci.avals)
|
|
|
|
with dispatch.log_elapsed_time(
|
|
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec"):
|
|
compiled = dispatch.compile_or_get_cached(
|
|
pci.backend, xla_computation, compile_options, host_callbacks)
|
|
handle_args = InputsHandler(
|
|
compiled.local_devices(), in_shardings, input_indices)
|
|
execute_fun = ExecuteReplicated(compiled, pci.backend, handle_args,
|
|
handle_outs, unordered_effects, keepalive)
|
|
fingerprint = getattr(compiled, "fingerprint", None)
|
|
|
|
return PmapExecutable(compiled, execute_fun, fingerprint, pci.avals)
|
|
|
|
# -- stages.XlaExecutable overrides
|
|
|
|
def xla_extension_executable(self):
|
|
return self.xla_executable
|
|
|
|
@profiler.annotate_function
|
|
def call(self, *args):
|
|
# TODO(frostig): do we need to check sharding and sharded avals?
|
|
arg_avals = map(xla.abstractify, args)
|
|
dispatch.check_arg_avals_for_call(self.in_avals, arg_avals)
|
|
return self.unsafe_call(*args)
|
|
|
|
|
|
def _get_pmap_sharding(devices, specs):
|
|
from jax.experimental.sharding import PmapSharding
|
|
|
|
return [PmapSharding(devices, spec) for spec in specs]
|
|
|
|
|
|
multi_host_supported_collectives: Set[core.Primitive] = set()
|
|
|
|
|
|
def check_multihost_collective_allowlist(jaxpr):
|
|
used_collectives = set(xla.jaxpr_collectives(jaxpr))
|
|
if not used_collectives.issubset(multi_host_supported_collectives):
|
|
bad_collectives = used_collectives - multi_host_supported_collectives
|
|
msg = "using collectives that aren't supported for multi-host: {}"
|
|
raise TypeError(msg.format(", ".join(map(str, bad_collectives))))
|
|
|
|
|
|
PartitionsOrReplicated = Optional[Tuple[int, ...]]
|
|
|
|
class PartitionInfo(NamedTuple):
|
|
arg_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
|
|
out_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
|
|
num_partitions: int
|
|
local_arg_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
|
|
local_out_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
|
|
local_num_partitions: Optional[int]
|
|
|
|
def _find_partitions(jaxpr):
|
|
"""Returns (in_partitions, out_partitions, num_partitions, local_in_parts,
|
|
local_out_parts, local_num_partitions).
|
|
"""
|
|
for eqn in jaxpr.eqns:
|
|
if eqn.primitive.name == "sharded_call":
|
|
if len(jaxpr.eqns) > 1:
|
|
raise NotImplementedError(
|
|
"pmap of sharded_jit + non-sharded operations not yet implemented.")
|
|
num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"],
|
|
eqn.params["nparts"])
|
|
return (eqn.params["in_parts"],
|
|
eqn.params["out_parts_thunk"](),
|
|
num_partitions,
|
|
eqn.params["local_in_parts"],
|
|
eqn.params["local_out_parts_thunk"](),
|
|
eqn.params["local_nparts"])
|
|
return None, None, 1, None, None, None
|
|
|
|
def find_partitions(jaxpr) -> PartitionInfo:
|
|
(arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts,
|
|
local_num_partitions) = _find_partitions(jaxpr)
|
|
|
|
if local_num_partitions is None:
|
|
local_num_partitions = num_partitions
|
|
if local_arg_parts is None:
|
|
local_arg_parts = arg_parts
|
|
if local_out_parts is None:
|
|
local_out_parts = out_parts
|
|
|
|
return PartitionInfo(arg_parts, out_parts, num_partitions,
|
|
local_arg_parts, local_out_parts, local_num_partitions)
|
|
|
|
|
|
def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]):
|
|
"""Returns the total number of partitions to use.
|
|
|
|
Validates that any inner partitioning matches outer_num_parts if provided, and
|
|
returns the number of partitions to use based on outer_num_parts and any inner
|
|
partitioning.
|
|
"""
|
|
inner_num_parts = _inner_partitions(jaxpr, outer_num_parts)
|
|
if outer_num_parts is None and inner_num_parts is None:
|
|
# No partitions specified anywhere, everything is replicated.
|
|
return 1
|
|
if outer_num_parts is None:
|
|
return inner_num_parts
|
|
return outer_num_parts
|
|
|
|
|
|
def _inner_partitions(jaxpr, expected_num_parts: Optional[int]):
|
|
"""Returns the total number of partitions from PartitionSpecs inside `jaxpr`.
|
|
|
|
Also validates that this number matches `expected_num_parts` if provided.
|
|
"""
|
|
for eqn in jaxpr.eqns:
|
|
if eqn.primitive.name in ["sharding_constraint", "infeed"]:
|
|
parts = eqn.params["partitions"]
|
|
nparts = get_num_partitions(parts)
|
|
if expected_num_parts is None:
|
|
expected_num_parts = nparts
|
|
elif nparts is not None and nparts != expected_num_parts:
|
|
# TODO(skye): raise this error as we trace the jaxpr
|
|
raise ValueError(
|
|
f"with_sharding_constraint with partitions={parts} "
|
|
f"(total partitions: {nparts}) doesn't match expected number of "
|
|
f"partitions: {expected_num_parts}. If these partitions look "
|
|
f"right, check outer sharded_jit and/or other "
|
|
f"with_sharding_constraint calls.")
|
|
else:
|
|
for subjaxpr in core.jaxprs_in_params(eqn.params):
|
|
expected_num_parts = _inner_partitions(subjaxpr, expected_num_parts)
|
|
return expected_num_parts
|
|
|
|
|
|
def get_num_partitions(*partitions):
|
|
partition_specs = tree_flatten(partitions)[0]
|
|
if len(partition_specs) == 0:
|
|
# Everything is specified as replicated (all Nones).
|
|
return None
|
|
num_partitions_set = {np.prod(spec) for spec in partition_specs}
|
|
if len(num_partitions_set) > 1:
|
|
raise ValueError(
|
|
f"All partition specs must use the same number of total partitions, "
|
|
f"got {partitions}, with distinct number of partitions "
|
|
f"{num_partitions_set} (the total number of partitions is the product "
|
|
f"of a partition spec)")
|
|
assert len(num_partitions_set) == 1
|
|
return num_partitions_set.pop()
|
|
|
|
|
|
def get_global_aval(local_aval, global_parts: PartitionsOrReplicated,
|
|
local_parts: PartitionsOrReplicated):
|
|
if global_parts is None:
|
|
return local_aval
|
|
assert local_parts is not None
|
|
global_shape = [dim * _safe_div(ngparts, nlparts)
|
|
for dim, ngparts, nlparts
|
|
in safe_zip(local_aval.shape, global_parts, local_parts)]
|
|
return local_aval.update(shape=global_shape)
|
|
|
|
|
|
def get_local_aval(global_aval, global_parts: PartitionsOrReplicated,
|
|
local_parts: PartitionsOrReplicated):
|
|
if global_parts is None:
|
|
return global_aval
|
|
assert local_parts is not None
|
|
local_shape = [_safe_div(dim, _safe_div(ngparts, nlparts))
|
|
for dim, ngparts, nlparts
|
|
in safe_zip(global_aval.shape, global_parts, local_parts)]
|
|
return global_aval.update(shape=local_shape)
|
|
|
|
|
|
def _safe_div(x, y):
|
|
result, ragged = divmod(x, y)
|
|
assert not ragged, f"{x} % {y} != 0"
|
|
return result
|
|
|
|
|
|
class InputsHandler:
|
|
__slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
|
|
|
|
def __init__(self, local_devices, in_shardings, input_indices):
|
|
self.handler = partial(shard_args, local_devices, input_indices)
|
|
self.local_devices = local_devices
|
|
self.in_shardings = in_shardings
|
|
self.input_indices = input_indices
|
|
|
|
def __call__(self, input_buffers):
|
|
return self.handler(input_buffers)
|
|
|
|
def __str__(self):
|
|
return ("InputsHandler(\n"
|
|
f"local_devices={self.local_devices},\n"
|
|
f"in_shardings={self.in_shardings},\n"
|
|
f"input_indices={self.input_indices})")
|
|
|
|
|
|
class ResultsHandler:
|
|
# `out_avals` is the `GlobalDeviceArray` global avals when using pjit or xmap
|
|
# with `config.parallel_functions_output_gda=True`. It is the local one
|
|
# otherwise, and also when using `pmap`.
|
|
__slots__ = ("handlers", "out_shardings", "out_avals")
|
|
|
|
def __init__(self, handlers, out_shardings, out_avals):
|
|
self.handlers = handlers
|
|
self.out_shardings = out_shardings
|
|
self.out_avals = out_avals
|
|
|
|
def __call__(self, out_bufs):
|
|
return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]
|
|
|
|
|
|
def _get_sharding_specs(
|
|
shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
|
|
) -> Sequence[ShardingSpec]:
|
|
from jax.experimental import sharding
|
|
|
|
if all(isinstance(s, sharding.PmapSharding) for s in shardings):
|
|
return [s.sharding_spec for s in shardings] # type: ignore
|
|
elif all(isinstance(s, sharding.MeshPspecSharding) for s in shardings):
|
|
return [new_mesh_sharding_specs(s.mesh.shape, s.mesh.axis_names)(
|
|
aval.ndim, _get_array_mapping(s.spec))
|
|
for aval, s in safe_zip(avals, shardings)]
|
|
else:
|
|
raise ValueError('Getting sharding spec is only supported for '
|
|
'PmapSharding and MeshPspecSharding.')
|
|
|
|
def local_avals_to_results_handler(
|
|
unmapped_local_out_avals: Sequence[Optional[ShapedArray]],
|
|
local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
|
|
local_out_specs = _get_sharding_specs(
|
|
local_shardings, cast(Sequence[ShapedArray], unmapped_local_out_avals))
|
|
out_indices = [tuple(s.devices_indices_map(aval.shape).values())
|
|
for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
|
|
handlers = [
|
|
local_aval_to_result_handler(aval, spec, idcs)
|
|
for aval, spec, idcs in safe_zip(unmapped_local_out_avals, local_out_specs, out_indices)
|
|
]
|
|
return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals)
|
|
|
|
|
|
def global_avals_to_results_handler(
|
|
global_out_avals: Sequence[ShapedArray],
|
|
shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
|
|
from jax.experimental.sharding import MeshPspecSharding
|
|
|
|
if config.jax_parallel_functions_output_gda or config.jax_array:
|
|
handlers = [
|
|
global_aval_to_result_handler(global_aval, s)
|
|
for global_aval, s in safe_zip(global_out_avals, shardings)
|
|
]
|
|
return ResultsHandler(handlers, shardings, global_out_avals)
|
|
else:
|
|
# This path is taken when the outputs are SDAs.
|
|
assert all(isinstance(s, MeshPspecSharding) for s in shardings)
|
|
local_out_avals = [s.mesh._global_to_local(_get_array_mapping(s.spec), aval)
|
|
for aval, s in safe_zip(global_out_avals, shardings)]
|
|
local_shardings = [MeshPspecSharding(s.mesh.local_mesh, s.spec) for s in shardings] # type: ignore
|
|
return local_avals_to_results_handler(local_out_avals, local_shardings)
|
|
|
|
|
|
@profiler.annotate_function
|
|
def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
|
|
"""Replicates ``val`` across multiple devices.
|
|
|
|
Args:
|
|
val: the value to be replicated.
|
|
axis_size: the length of the output, i.e. the logical number of replicas to
|
|
create. Usually equal to `nrep`, but in the case of nested pmaps, `nrep` may
|
|
be a multiple of `axis_size`.
|
|
nrep: the number of replicas to create. If ``devices`` is set, must be equal
|
|
to ``len(devices)``.
|
|
devices: the devices to replicate across. If None, ``nrep`` will be used to
|
|
generate a default device assignment.
|
|
backend: string specifying which backend to use.
|
|
in_axis: axis along which the value is to be replciated.
|
|
|
|
Returns:
|
|
A ShardedDeviceArray of length `axis_size` where each shard is equal to
|
|
``val``.
|
|
"""
|
|
device_count = (len(devices) if devices else xb.local_device_count(backend))
|
|
if nrep > device_count:
|
|
msg = ("Cannot replicate across %d replicas because only %d local devices "
|
|
"are available." % (nrep, device_count))
|
|
if devices:
|
|
msg += (" (local devices = %s)"
|
|
% ", ".join(map(str, devices)) if devices else str(None))
|
|
raise ValueError(msg)
|
|
|
|
if devices is None:
|
|
assert nrep is not None
|
|
# TODO(skye): use different device assignment on multihost
|
|
devices = xb.get_backend(backend).get_default_device_assignment(nrep)
|
|
assert nrep == len(devices)
|
|
|
|
aval = xla.abstractify(val) # type: ShapedArray
|
|
if in_axis is not None:
|
|
replicated_aval = aval.update(shape=(axis_size,) + aval.shape)
|
|
else:
|
|
replicated_aval = aval
|
|
# TODO(skye): figure out how partitioning should work here
|
|
sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis)
|
|
device_buffers = device_put(val, devices, replicate=True)
|
|
return make_sharded_device_array(replicated_aval, sharding_spec,
|
|
device_buffers)
|
|
|
|
|
|
def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval,
|
|
map_axis: Optional[int]) -> ShardingSpec:
|
|
"""Sharding spec for arguments or results of a pmap.
|
|
Args:
|
|
nrep: number of local XLA replicas (product of local axis sizes)
|
|
axis_size: local axis size for outer pmap
|
|
npart: total number of XLA partitions (required by sharded_jit calls)
|
|
parts: the partitioning of the value or None
|
|
sharded_aval: the aval of the value inside the outer pmap, an instance of
|
|
a ShapedArray.
|
|
map_axis: the axis along which the value is mapped in the outer pmap
|
|
Returns:
|
|
A ShardingSpec.
|
|
"""
|
|
assert isinstance(sharded_aval, ShapedArray), sharded_aval
|
|
replication_factor, ragged = divmod(nrep, axis_size)
|
|
assert not ragged
|
|
# get the sharding spec from inner sharded_jits as if we weren't in a pmap
|
|
pspec = partitioned_sharding_spec(npart, parts, sharded_aval)
|
|
maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
|
|
if map_axis is not None:
|
|
sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
|
|
def shift_sharded_axis(a: MeshDimAssignment):
|
|
if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
|
|
return ShardedAxis(a.axis + 1)
|
|
return a
|
|
# replication_factor represents the product of inner pmaps, so it goes
|
|
# after the outer pmapped axis at index 0
|
|
return ShardingSpec(
|
|
sharding=tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)),
|
|
mesh_mapping=it.chain([ShardedAxis(sharded_in_axis)],
|
|
maybe_replicate,
|
|
map(shift_sharded_axis, pspec.mesh_mapping)))
|
|
else:
|
|
return ShardingSpec(
|
|
sharding=pspec.sharding,
|
|
mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
|
|
|
|
def partitioned_sharding_spec(num_partitions: int,
|
|
partitions: Optional[Sequence[int]],
|
|
aval) -> ShardingSpec:
|
|
if partitions is None:
|
|
maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
|
|
return ShardingSpec(
|
|
sharding=[_UNSHARDED_INSTANCE] * len(aval.shape),
|
|
mesh_mapping=maybe_replicate)
|
|
else:
|
|
assert len(partitions) == len(aval.shape)
|
|
return ShardingSpec(
|
|
# Chunked expects a list of integers
|
|
sharding=map(Chunked, [[x] for x in partitions]),
|
|
mesh_mapping=map(ShardedAxis, range(len(partitions))))
|
|
|
|
|
|
class ExecuteReplicated:
|
|
"""The logic to shard inputs, execute a replicated model, returning outputs."""
|
|
__slots__ = ['xla_executable', 'backend', 'in_handler', 'out_handler',
|
|
'has_unordered_effects', 'keepalive']
|
|
|
|
def __init__(self, xla_executable, backend, in_handler: InputsHandler,
|
|
out_handler: ResultsHandler,
|
|
unordered_effects: List[core.Effect], keepalive: Any):
|
|
self.xla_executable = xla_executable
|
|
self.backend = backend
|
|
self.in_handler = in_handler
|
|
self.out_handler = out_handler
|
|
self.has_unordered_effects = bool(unordered_effects)
|
|
self.keepalive = keepalive
|
|
|
|
@profiler.annotate_function
|
|
def __call__(self, *args):
|
|
input_bufs = self.in_handler(args)
|
|
if self.has_unordered_effects:
|
|
# TODO(sharadmv): simplify this logic when minimum jaxlib version is
|
|
# bumped
|
|
if xla_extension_version >= 81:
|
|
out_bufs, runtime_tokens = (
|
|
self.xla_executable.execute_sharded_on_local_devices_with_tokens(
|
|
input_bufs))
|
|
for device, token in zip(
|
|
self.xla_executable.local_devices(), runtime_tokens):
|
|
dispatch.runtime_tokens.set_output_runtime_token(device, token)
|
|
else:
|
|
out_bufs = self.xla_executable.execute_sharded_on_local_devices(
|
|
input_bufs)
|
|
token_bufs, *out_bufs = out_bufs
|
|
for i, device in enumerate(self.xla_executable.local_devices()):
|
|
token = (token_bufs[i],)
|
|
dispatch.runtime_tokens.set_output_token(device, token)
|
|
else:
|
|
out_bufs = self.xla_executable.execute_sharded_on_local_devices(
|
|
input_bufs)
|
|
if dispatch.needs_check_special():
|
|
for bufs in out_bufs:
|
|
dispatch.check_special("parallel computation", bufs)
|
|
return self.out_handler(out_bufs)
|
|
|
|
|
|
xla_pmap_p = core.MapPrimitive('xla_pmap')
|
|
xla_pmap = xla_pmap_p.bind
|
|
xla_pmap_p.def_impl(xla_pmap_impl)
|
|
|
|
def _pmap_partial_eval_custom_params_updater(
|
|
unks_in, inst_in, kept_outs_known, kept_outs_staged, num_res, params_known,
|
|
params_staged):
|
|
# prune inputs to jaxpr_known according to unks_in
|
|
donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
|
|
in_axes_known, _ = partition_list(unks_in, params_known['in_axes'])
|
|
_, out_axes_known = partition_list(kept_outs_known, params_known['out_axes'])
|
|
out_axes_known = out_axes_known + [0] * num_res
|
|
new_params_known = dict(params_known, in_axes=tuple(in_axes_known),
|
|
out_axes=tuple(out_axes_known),
|
|
donated_invars=tuple(donated_invars_known))
|
|
|
|
# added num_res new inputs to jaxpr_staged, pruning according to inst_in
|
|
_, donated_invars_staged = partition_list(inst_in, params_staged['donated_invars'])
|
|
donated_invars_staged = [False] * num_res + donated_invars_staged
|
|
_, in_axes_staged = partition_list(inst_in, params_staged['in_axes'])
|
|
in_axes_staged = [0] * num_res + in_axes_staged
|
|
_, out_axes_staged = partition_list(kept_outs_staged, params_staged['out_axes'])
|
|
new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
|
|
out_axes=tuple(out_axes_staged),
|
|
donated_invars=tuple(donated_invars_staged))
|
|
return new_params_known, new_params_staged
|
|
|
|
def _pmap_partial_eval_custom_res_maker(params_known, aval):
|
|
return core.unmapped_aval(params_known['axis_size'], core.no_axis_name, 0, aval)
|
|
|
|
def _pmap_dce_rule(used_outputs, eqn):
|
|
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
|
|
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
|
|
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
|
|
_, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
|
|
new_params = dict(eqn.params, call_jaxpr=new_jaxpr, in_axes=tuple(in_axes),
|
|
out_axes=tuple(out_axes))
|
|
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
|
|
return used_inputs, None
|
|
else:
|
|
new_eqn = pe.new_jaxpr_eqn(
|
|
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
|
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
|
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
|
|
return used_inputs, new_eqn
|
|
|
|
|
|
# Set param update handlers to update `donated_invars` just like xla_call_p
|
|
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
|
|
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
|
|
partial(pe.call_partial_eval_custom_rule,
|
|
'call_jaxpr', _pmap_partial_eval_custom_params_updater,
|
|
res_aval=_pmap_partial_eval_custom_res_maker)
|
|
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
|
|
ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
|
|
ad.call_transpose_param_updaters[xla_pmap_p] = \
|
|
ad.call_transpose_param_updaters[xla.xla_call_p]
|
|
|
|
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
|
|
|
|
|
|
def _unravel_index_mhlo(axis_env):
|
|
div = mlir.ir_constant(
|
|
np.array(axis_env.nreps // util.prod(axis_env.sizes), np.uint32))
|
|
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
|
|
return mhlo.RemOp(
|
|
mhlo.DivOp(mhlo.ReplicaIdOp().result, div).result, mod).result
|
|
|
|
def _mhlo_shard(aval, axis_env, xs, in_axis):
|
|
if aval is core.abstract_token:
|
|
return xs
|
|
elif isinstance(aval, core.ShapedArray):
|
|
x, = xs
|
|
dims = list(aval.shape)
|
|
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
|
idxs = [zero] * len(dims)
|
|
idxs.insert(in_axis, _unravel_index_mhlo(axis_env))
|
|
dims_unsqueezed = dims.copy()
|
|
dims_unsqueezed.insert(in_axis, 1)
|
|
dynamic_slice_result = mhlo.DynamicSliceOp(
|
|
x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
|
|
return [
|
|
mhlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
|
|
]
|
|
else:
|
|
raise TypeError(aval)
|
|
|
|
|
|
# TODO(b/110096942): more efficient gather
|
|
def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
|
|
if aval is core.abstract_token:
|
|
return xs
|
|
elif isinstance(aval, core.ShapedArray):
|
|
x, = xs
|
|
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
|
|
convert_bool = (np.issubdtype(aval.dtype, np.bool_)
|
|
and platform in ('cpu', 'gpu'))
|
|
if convert_bool:
|
|
aval = aval.update(dtype=np.dtype(np.float32))
|
|
x = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
|
|
|
|
dims = list(aval.shape)
|
|
padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
|
|
padded = mlir.full_like_aval(0, padded_aval)
|
|
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
|
idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
|
|
broadcast_result = mhlo.BroadcastOp(
|
|
x, mlir.dense_int_elements([1])).result
|
|
padded = mhlo.DynamicUpdateSliceOp(
|
|
padded.type, padded, broadcast_result, idxs).result
|
|
replica_groups = mlir.dense_int_elements(
|
|
xla.axis_groups(axis_env, axis_env.names[-1]))
|
|
out = mhlo.CrossReplicaSumOp(padded, replica_groups).result
|
|
if out_axis != 0:
|
|
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
|
|
perm = list(range(1, len(dims)))
|
|
perm.insert(out_axis, 0)
|
|
transposed_dims = list(dims)
|
|
transposed_dims.insert(out_axis, axis_env.sizes[-1])
|
|
aval = aval.update(shape=transposed_dims)
|
|
out = mhlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
|
|
|
|
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
|
|
if convert_bool:
|
|
float_zero = mlir.full_like_aval(0, padded_aval)
|
|
out = mhlo.CompareOp(
|
|
out,
|
|
float_zero,
|
|
mhlo.ComparisonDirectionAttr.get("NE"),
|
|
compare_type=mhlo.ComparisonTypeAttr.get("FLOAT")).result
|
|
return out
|
|
else:
|
|
raise TypeError(aval)
|
|
|
|
|
|
def _pmap_lowering(ctx, *in_nodes, axis_name,
|
|
axis_size, global_axis_size, devices, name,
|
|
call_jaxpr, backend=None, in_axes, out_axes,
|
|
donated_invars, global_arg_shapes):
|
|
del donated_invars # Unused.
|
|
xla.check_backend_matches(backend, ctx.module_context.platform)
|
|
# We in-line here rather than generating a Call HLO as in the xla_call
|
|
# translation rule just because the extra tuple stuff is a pain.
|
|
if ctx.module_context.axis_env.names and devices is not None:
|
|
raise ValueError("Nested pmap with explicit devices argument.")
|
|
if global_axis_size is None:
|
|
global_axis_size = axis_size
|
|
new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
|
|
global_axis_size)
|
|
# Shard the in_nodes that are mapped
|
|
in_avals = [v.aval for v in call_jaxpr.invars]
|
|
in_nodes_sharded = (
|
|
_mhlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
|
|
if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
|
|
for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))
|
|
|
|
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
|
sub_ctx = ctx.module_context.replace(
|
|
axis_context=mlir.ReplicaAxisContext(new_env),
|
|
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
|
|
util.wrap_name(name, 'pmap')))
|
|
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
|
|
*in_nodes_sharded)
|
|
out_avals = [v.aval for v in call_jaxpr.outvars]
|
|
outs = [_mhlo_unshard(aval, new_env, out_axis, shard,
|
|
platform=ctx.module_context.platform)
|
|
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
|
|
return outs
|
|
|
|
mlir.register_lowering(xla_pmap_p, _pmap_lowering)
|
|
|
|
|
|
# ------------------- xmap -------------------
|
|
|
|
class Mesh(ContextDecorator):
|
|
"""Declare the hardware resources available in the scope of this manager.
|
|
|
|
In particular, all ``axis_names`` become valid resource names inside the
|
|
managed block and can be used e.g. in the ``in_axis_resources`` argument of
|
|
:py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html)
|
|
and pjit tutorial (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html).
|
|
|
|
If you are compiling in multiple threads, make sure that the
|
|
``with Mesh`` context manager is inside the function that the threads will
|
|
execute.
|
|
|
|
Args:
|
|
devices: A NumPy ndarray object containing JAX device objects (as
|
|
obtained e.g. from :py:func:`jax.devices`).
|
|
axis_names: A sequence of resource axis names to be assigned to the
|
|
dimensions of the ``devices`` argument. Its length should match the
|
|
rank of ``devices``.
|
|
|
|
Example:
|
|
|
|
>>> from jax.experimental.maps import Mesh
|
|
>>> from jax.experimental.pjit import pjit
|
|
>>> from jax.experimental import PartitionSpec as P
|
|
>>> import numpy as np
|
|
...
|
|
>>> inp = np.arange(16).reshape((8, 2))
|
|
>>> devices = np.array(jax.devices()).reshape(4, 2)
|
|
...
|
|
>>> # Declare a 2D mesh with axes `x` and `y`.
|
|
>>> global_mesh = Mesh(devices, ('x', 'y'))
|
|
>>> # Use the mesh object directly as a context manager.
|
|
>>> with global_mesh:
|
|
... out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
|
|
|
|
>>> # Initialize the Mesh and use the mesh as the context manager.
|
|
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
|
|
... out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
|
|
|
|
>>> # Also you can use it as `with ... as ...`.
|
|
>>> global_mesh = Mesh(devices, ('x', 'y'))
|
|
>>> with global_mesh as m:
|
|
... out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
|
|
|
|
>>> # You can also use it as `with Mesh(...)`.
|
|
>>> with Mesh(devices, ('x', 'y')):
|
|
... out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
|
|
"""
|
|
|
|
devices: np.ndarray
|
|
axis_names: Tuple[MeshAxisName, ...]
|
|
|
|
def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
|
|
assert devices.ndim == len(axis_names)
|
|
# TODO: Make sure that devices are unique? At least with the quick and
|
|
# dirty check that the array size is not larger than the number of
|
|
# available devices?
|
|
self.devices = devices.copy()
|
|
self.devices.flags.writeable = False
|
|
self.axis_names = tuple(axis_names)
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, Mesh):
|
|
return False
|
|
# This is a performance optimization. Comparing thousands of devices
|
|
# can be expensive.
|
|
if id(self) == id(other):
|
|
return True
|
|
return (self.axis_names == other.axis_names and
|
|
np.array_equal(self.devices, other.devices))
|
|
|
|
def __hash__(self):
|
|
if not hasattr(self, '_hash'):
|
|
self._hash = hash(
|
|
(self.axis_names, tuple(self.devices.flat), self.devices.shape))
|
|
return self._hash
|
|
|
|
def __setattr__(self, name, value):
|
|
if hasattr(self, name):
|
|
raise RuntimeError("Cannot reassign attributes of immutable mesh objects")
|
|
super().__setattr__(name, value)
|
|
|
|
def __enter__(self):
|
|
new_env = _old_env.stack[-1].with_mesh(self)
|
|
_old_env.stack.append(new_env)
|
|
thread_resources.env = new_env
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
_old_env.stack.pop()
|
|
thread_resources.env = _old_env.stack[-1]
|
|
return False
|
|
|
|
@property
|
|
def shape(self):
|
|
return OrderedDict((name, size) for name, size in safe_zip(self.axis_names, self.devices.shape))
|
|
|
|
@property
|
|
def size(self):
|
|
return np.prod(list(self.shape.values()))
|
|
|
|
@property
|
|
def empty(self):
|
|
return self.devices.ndim == 0
|
|
|
|
@property
|
|
def is_multi_process(self):
|
|
return self.devices.size != len(self.local_devices)
|
|
|
|
@maybe_cached_property
|
|
def local_mesh(self):
|
|
return self._local_mesh(xb.process_index())
|
|
|
|
def _local_mesh(self, process_index):
|
|
if self.empty:
|
|
return self
|
|
is_local_device = np.vectorize(
|
|
lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
|
|
subcube_indices = []
|
|
# We take the smallest slice of each dimension that doesn't skip any local device.
|
|
for axis in range(self.devices.ndim):
|
|
other_axes = tuple_delete(tuple(range(self.devices.ndim)), axis)
|
|
# NOTE: This re-reduces over many axes multiple times, so we could definitely
|
|
# optimize it, but I hope it won't be a bottleneck anytime soon.
|
|
local_slices = is_local_device.any(other_axes, keepdims=False)
|
|
nonzero_indices = np.flatnonzero(local_slices)
|
|
start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
|
|
subcube_indices.append(slice(start, end + 1))
|
|
subcube_indices = tuple(subcube_indices)
|
|
# We only end up with all conditions being true if the local devices formed a
|
|
# subcube of the full array. This is because we were biased towards taking a
|
|
# "hull" spanned by the devices, and in case the local devices don't form a
|
|
# subcube that hull will contain non-local devices.
|
|
if not is_local_device[subcube_indices].all():
|
|
raise ValueError(
|
|
"When passing non-GlobalDeviceArray inputs to pjit or xmap, devices "
|
|
"connected to a single host must form a contiguous subcube of the "
|
|
"global device mesh")
|
|
return Mesh(self.devices[subcube_indices], self.axis_names)
|
|
|
|
@property
|
|
def device_ids(self):
|
|
assert not self.empty
|
|
return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)
|
|
|
|
def __repr__(self):
|
|
if self.empty:
|
|
return "Mesh([], ())"
|
|
return f"Mesh({self.device_ids!r}, {self.axis_names!r})"
|
|
|
|
@maybe_cached_property
|
|
def local_devices(self):
|
|
process_index = xb.process_index()
|
|
return [d for d in self.devices.flat if d.process_index == process_index]
|
|
|
|
def _local_to_global(self, axes: ArrayMapping, aval):
|
|
return untile_aval_nd(self.shape, axes,
|
|
tile_aval_nd(self.local_mesh.shape, axes, aval))
|
|
|
|
def _global_to_local(self, axes: ArrayMapping, aval):
|
|
return untile_aval_nd(self.local_mesh.shape, axes,
|
|
tile_aval_nd(self.shape, axes, aval))
|
|
|
|
|
|
ResourceAxisName = core.AxisName
|
|
|
|
class _Loop(NamedTuple):
|
|
name: ResourceAxisName
|
|
length: int
|
|
|
|
|
|
def show_axes(axes):
|
|
return ", ".join(sorted(f"`{a}`" for a in axes))
|
|
|
|
|
|
class ResourceEnv(NamedTuple):
|
|
physical_mesh: Mesh
|
|
loops: Tuple[_Loop, ...]
|
|
|
|
def with_mesh(self, mesh: Mesh):
|
|
overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
|
|
if overlap:
|
|
raise ValueError(f"Cannot update the mesh of the current resource "
|
|
f"environment. The new mesh shadows already defined axes "
|
|
f"{show_axes(overlap)}")
|
|
return self._replace(physical_mesh=mesh)
|
|
|
|
def with_extra_loop(self, loop: _Loop):
|
|
if loop.name in self.resource_axes:
|
|
raise ValueError(f"Cannot extend the resource environment with loop named "
|
|
f"`{loop.name}`. An axis of this name is already defined!")
|
|
return self._replace(loops=self.loops + (loop,))
|
|
|
|
@property
|
|
def physical_resource_axes(self) -> Set[ResourceAxisName]:
|
|
return set(self.physical_mesh.axis_names)
|
|
|
|
@property
|
|
def loop_resource_axes(self) -> Set[ResourceAxisName]:
|
|
return {loop.name for loop in self.loops}
|
|
|
|
@property
|
|
def resource_axes(self) -> Set[ResourceAxisName]:
|
|
return self.physical_resource_axes | self.loop_resource_axes
|
|
|
|
@property
|
|
def shape(self):
|
|
shape = self.physical_mesh.shape
|
|
shape.update(self.loops)
|
|
return shape
|
|
|
|
@property
|
|
def local_shape(self):
|
|
shape = self.physical_mesh.local_mesh.shape
|
|
shape.update(self.loops)
|
|
return shape
|
|
|
|
def __repr__(self):
|
|
return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"
|
|
|
|
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())
|
|
|
|
class _ThreadResourcesLocalState(threading.local):
|
|
|
|
def __init__(self):
|
|
self.env = EMPTY_ENV
|
|
|
|
thread_resources = _ThreadResourcesLocalState()
|
|
|
|
# TODO(yashkatariya): Merge this into `_ThreadResourcesLocalState` by
|
|
# maintaining a stack there and pointing `self.env` to `self.stack[-1]`.
|
|
# Do this after the old `mesh` context manager is deprecated.
|
|
class _ThreadLocalOldEnv(threading.local):
|
|
def __init__(self):
|
|
self.stack = [EMPTY_ENV]
|
|
|
|
_old_env = _ThreadLocalOldEnv()
|
|
|
|
|
|
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
|
|
assert isinstance(aval, ShapedArray)
|
|
shape = list(aval.shape)
|
|
named_shape = dict(aval.named_shape)
|
|
for name, axis in in_axes.items():
|
|
assert shape[axis] % axis_sizes[name] == 0
|
|
assert name not in named_shape
|
|
named_shape[name] = axis_sizes[name]
|
|
shape[axis] //= axis_sizes[name]
|
|
return aval.update(shape=tuple(shape), named_shape=named_shape)
|
|
|
|
def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
|
|
assert isinstance(aval, ShapedArray)
|
|
shape = list(aval.shape)
|
|
named_shape = dict(aval.named_shape)
|
|
for name, axis in out_axes.items():
|
|
shape[axis] *= axis_sizes[name]
|
|
named_shape.pop(name, None) # The name might be missing --- it's a broadcast.
|
|
return aval.update(shape=tuple(shape), named_shape=named_shape)
|
|
|
|
|
|
class SPMDBatchTrace(batching.BatchTrace):
|
|
def get_axis_primitive_batcher(self, primitive, frame):
|
|
if primitive in spmd_primitive_batchers:
|
|
return partial(spmd_primitive_batchers[primitive],
|
|
frame.size, frame.name, frame.main_trace.trace_type)
|
|
return super().get_axis_primitive_batcher(primitive, frame)
|
|
|
|
|
|
spmd_primitive_batchers: Dict[core.Primitive, Callable] = {}
|
|
|
|
|
|
def vtile_by_mesh(fun: lu.WrappedFun,
|
|
mesh: Mesh,
|
|
in_axes: Sequence[ArrayMapping],
|
|
out_axes: Sequence[ArrayMapping]):
|
|
# We vectorize in reversed order, because vmap is often biased towards
|
|
# moving the batch axis to the front, and this way of stacking transforms
|
|
# will order the batch axes according to the mesh axis order.
|
|
# Not strictly necessary, but seems nicer than reversing it?
|
|
for name, size in reversed(mesh.shape.items()):
|
|
fun = batching.vtile(fun,
|
|
tuple(a.get(name, None) for a in in_axes),
|
|
tuple(a.get(name, None) for a in out_axes),
|
|
tile_size=size,
|
|
axis_name=name,
|
|
main_type=SPMDBatchTrace)
|
|
return fun
|
|
|
|
full_to_shard_p = core.Primitive('full_to_shard')
|
|
|
|
@full_to_shard_p.def_abstract_eval
|
|
def _full_to_shard_abstract_eval(x, axes, mesh, **_):
|
|
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
|
|
return tile_aval_nd(mesh.shape, axes, x)
|
|
|
|
def _manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName], mesh: Mesh):
|
|
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
|
|
and all others as replicated.
|
|
"""
|
|
named_mesh_shape = mesh.shape
|
|
mesh_shape = list(named_mesh_shape.values())
|
|
axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}
|
|
|
|
manual_axes = list(sorted(manual_axes_set, key=str))
|
|
replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set)
|
|
|
|
tad_perm = ([axis_order[a] for a in replicated_axes] +
|
|
[axis_order[a] for a in manual_axes])
|
|
tad_shape = [1] * aval.ndim
|
|
tad_shape.append(int(np.prod([named_mesh_shape[a] for a in replicated_axes], dtype=int)))
|
|
tad_shape.append(int(np.prod([named_mesh_shape[a] for a in manual_axes], dtype=int)))
|
|
|
|
raw_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
|
|
proto = xc.OpSharding()
|
|
proto.type = xc.OpSharding.Type.OTHER
|
|
proto.tile_assignment_dimensions = tad_shape
|
|
proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat)
|
|
proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
|
|
return proto
|
|
|
|
@partial(mlir.register_lowering, full_to_shard_p)
|
|
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
|
|
# TODO: Can we short-circuit for replicated values? Probably not.
|
|
aval_in, = ctx.avals_in
|
|
aval_out, = ctx.avals_out
|
|
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto()
|
|
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
|
|
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=unspecified_dims)
|
|
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
|
|
result_type, = mlir.aval_to_ir_types(aval_out)
|
|
return mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, unspecified_dims=unspecified_dims),
|
|
|
|
shard_to_full_p = core.Primitive('shard_to_full')
|
|
|
|
@shard_to_full_p.def_abstract_eval
|
|
def _shard_to_full_abstract_eval(x, axes, mesh, **_):
|
|
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
|
|
return untile_aval_nd(mesh.shape, axes, x)
|
|
|
|
@partial(mlir.register_lowering, shard_to_full_p)
|
|
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
|
|
aval_in, = ctx.avals_in
|
|
aval_out, = ctx.avals_out
|
|
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
|
|
result_type, = mlir.aval_to_ir_types(aval_out)
|
|
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
|
|
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=unspecified_dims)
|
|
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto()
|
|
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
|
|
|
|
@lu.transformation
|
|
def vtile_manual(manual_axes: FrozenSet[MeshAxisName],
|
|
mesh: Mesh,
|
|
in_axes: Sequence[ArrayMapping],
|
|
out_axes: Sequence[ArrayMapping],
|
|
*args):
|
|
tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes)
|
|
for arg, axes in zip(args, in_axes)]
|
|
tiled_outs = yield tiled_args, {}
|
|
outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes)
|
|
for out, axes in zip(tiled_outs, out_axes)]
|
|
yield outs
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TileVectorize:
|
|
pass
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TileManual:
|
|
manual_axes: FrozenSet[MeshAxisName]
|
|
|
|
TilingMethod = Union[TileVectorize, TileManual]
|
|
|
|
|
|
def _check_if_any_auto(shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource]]) -> bool:
|
|
for s in shardings:
|
|
if _is_auto(s):
|
|
return True
|
|
return False
|
|
|
|
|
|
class _UnconstrainedPartitionSingleton:
|
|
|
|
def __str__(self):
|
|
return "UNCONSTRAINED"
|
|
|
|
|
|
# Unconstrained sentinel value for PartitionSpec, representing a dimension for
|
|
# which the user wants XLA to assign the best partitioning.
|
|
# TODO(yashkatariya): May rename to AUTO.
|
|
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
|
|
|
|
|
|
class PartitionSpec(tuple):
|
|
"""Tuple of integer specifying how a value should be partitioned.
|
|
|
|
Each integer corresponds to how many ways a dimension is partitioned. We
|
|
create a separate class for this so JAX's pytree utilities can distinguish it
|
|
from a tuple that should be treated as a pytree.
|
|
"""
|
|
|
|
def __init__(self, *partitions):
|
|
pass
|
|
|
|
def __new__(cls, *partitions):
|
|
return tuple.__new__(PartitionSpec, partitions)
|
|
|
|
def __repr__(self):
|
|
return "PartitionSpec%s" % tuple.__repr__(self)
|
|
|
|
def __reduce__(self):
|
|
return (PartitionSpec, tuple(self))
|
|
|
|
"""A sentinel value representing a dim is unconstrained."""
|
|
UNCONSTRAINED = _UNCONSTRAINED_PARTITION
|
|
|
|
|
|
def _get_backend_from_shardings(
|
|
shardings: Iterable[XLACompatibleSharding]) -> Tuple[xb.XlaBackend, XLACompatibleSharding]:
|
|
da = None
|
|
first_sharding = None
|
|
for s in shardings:
|
|
if _is_unspecified(s):
|
|
continue
|
|
da = s._device_assignment
|
|
first_sharding = s
|
|
break
|
|
assert len(da) > 0 # type: ignore
|
|
return xb.get_device_backend(da[0]), first_sharding # type: ignore
|
|
|
|
|
|
@profiler.annotate_function
|
|
def lower_sharding_computation(
|
|
fun: lu.WrappedFun,
|
|
api_name: str,
|
|
fun_name: str,
|
|
in_shardings: Sequence[XLACompatibleSharding],
|
|
out_shardings: Sequence[XLACompatibleSharding],
|
|
donated_invars: Sequence[bool],
|
|
global_in_avals: Sequence[core.ShapedArray],
|
|
in_is_global: Sequence[bool]):
|
|
# Device assignment across all inputs and outputs should be the same. This
|
|
# is checked in pjit.
|
|
backend, first_sharding = _get_backend_from_shardings(
|
|
it.chain(in_shardings, out_shardings)) # type: ignore
|
|
name_stack = new_name_stack(wrap_name(fun_name, api_name))
|
|
|
|
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
|
logging.log(log_priority,
|
|
"Compiling %s (%d) for with global shapes and types %s. "
|
|
"Argument mapping: %s.",
|
|
getattr(fun, '__name__', '<unnamed function>'), id(fun),
|
|
global_in_avals, in_shardings)
|
|
|
|
# 1. Trace to jaxpr and preprocess/verify it
|
|
in_jaxpr_avals = global_in_avals
|
|
|
|
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
|
"in {elapsed_time} sec"):
|
|
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
|
|
assert len(out_shardings) == len(out_jaxpr_avals)
|
|
|
|
global_out_avals = out_jaxpr_avals
|
|
|
|
_sanitize_mesh_jaxpr(jaxpr)
|
|
if not first_sharding.is_fully_addressable():
|
|
check_multihost_collective_allowlist(jaxpr)
|
|
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
|
|
|
# 2. Build up the HLO
|
|
tuple_args = len(in_jaxpr_avals) > 100 # pass long arg lists as tuple for TPU
|
|
in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
|
|
out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
|
|
axis_ctx: mlir.ShardingContext
|
|
|
|
in_op_shardings = [i._to_xla_op_sharding(aval.ndim)
|
|
for aval, i in safe_zip(global_in_avals, in_shardings)]
|
|
# TODO(yashkatariya): Fix the HLO produced if out_partitions is
|
|
# [None, OpShardingProto] has the sharding annotations.
|
|
out_op_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
|
|
for aval, o in safe_zip(global_out_avals, out_shardings)]
|
|
replicated_args = [False] * len(in_jaxpr_avals)
|
|
axis_ctx = mlir.ShardingContext(first_sharding)
|
|
|
|
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
|
module: Union[str, xc.XlaComputation]
|
|
module_name = f"{api_name}_{fun_name}"
|
|
|
|
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
|
|
raise ValueError("Ordered effects not supported in mesh computations.")
|
|
unordered_effects = [eff for eff in closed_jaxpr.effects
|
|
if eff not in core.ordered_effects]
|
|
lowering_result = mlir.lower_jaxpr_to_module(
|
|
module_name, closed_jaxpr, unordered_effects, [], backend.platform,
|
|
axis_ctx, name_stack, donated_invars, replicated_args=replicated_args,
|
|
arg_shardings=in_op_shardings, result_shardings=out_op_shardings)
|
|
module, keepalive, host_callbacks = (
|
|
lowering_result.module, lowering_result.keepalive,
|
|
lowering_result.host_callbacks)
|
|
|
|
return MeshComputation(
|
|
str(name_stack),
|
|
module,
|
|
donated_invars,
|
|
mesh=None,
|
|
global_in_avals=global_in_avals,
|
|
global_out_avals=global_out_avals,
|
|
in_shardings=in_shardings,
|
|
out_shardings=out_shardings,
|
|
spmd_lowering=True,
|
|
tuple_args=tuple_args,
|
|
in_is_global=in_is_global,
|
|
auto_spmd_lowering=False,
|
|
unordered_effects=unordered_effects,
|
|
host_callbacks=host_callbacks,
|
|
keepalive=keepalive)
|
|
|
|
|
|
@profiler.annotate_function
|
|
def lower_mesh_computation(
|
|
fun: lu.WrappedFun,
|
|
api_name: str,
|
|
fun_name: str,
|
|
mesh: Mesh,
|
|
in_shardings: Sequence[Union[MeshPspecSharding, _AUTOAxisResource]],
|
|
out_shardings: Sequence[Union[MeshPspecSharding, _AUTOAxisResource,
|
|
_UnspecifiedValue]],
|
|
donated_invars: Sequence[bool],
|
|
spmd_lowering: bool,
|
|
global_in_avals: Sequence[core.ShapedArray],
|
|
tiling_method: Optional[TilingMethod],
|
|
in_is_global: Sequence[bool]):
|
|
assert not mesh.empty
|
|
backend = xb.get_device_backend(mesh.devices.flat[0])
|
|
name_stack = new_name_stack(wrap_name(fun_name, api_name))
|
|
|
|
auto_spmd_lowering = _check_if_any_auto(in_shardings + out_shardings) # type: ignore
|
|
|
|
if auto_spmd_lowering and not spmd_lowering:
|
|
raise ValueError('Enable spmd_lowering to use auto spmd lowering.')
|
|
|
|
global_axis_sizes = mesh.shape
|
|
|
|
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
|
logging.log(log_priority,
|
|
"Compiling %s (%d) for %s mesh with global shapes and types %s. "
|
|
"Argument mapping: %s.",
|
|
getattr(fun, '__name__', '<unnamed function>'), id(fun),
|
|
tuple(global_axis_sizes.items()), global_in_avals,
|
|
in_shardings)
|
|
|
|
# 1. Trace to jaxpr and preprocess/verify it
|
|
if spmd_lowering:
|
|
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
|
|
if tiling_method is not None:
|
|
if isinstance(tiling_method, TileVectorize):
|
|
tiling_transform = vtile_by_mesh
|
|
elif isinstance(tiling_method, TileManual):
|
|
tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) # type: ignore
|
|
else:
|
|
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
|
|
assert not callable(out_shardings)
|
|
assert not auto_spmd_lowering
|
|
# This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
|
|
# is why `.spec` can be accessed.
|
|
fun = tiling_transform(
|
|
fun, mesh, [_get_array_mapping(i.spec) for i in in_shardings], # type: ignore
|
|
[_get_array_mapping(o.spec) for o in out_shardings]) # type: ignore
|
|
in_jaxpr_avals = global_in_avals
|
|
else:
|
|
assert isinstance(tiling_method, TileVectorize)
|
|
assert not auto_spmd_lowering
|
|
# In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
|
|
# why `.spec` can be accessed.
|
|
in_tiled_avals = [tile_aval_nd(global_axis_sizes, _get_array_mapping(i.spec), aval) # type: ignore
|
|
for aval, i in safe_zip(global_in_avals, in_shardings)]
|
|
in_jaxpr_avals = in_tiled_avals
|
|
with core.extend_axis_env_nd(mesh.shape.items()):
|
|
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
|
"in {elapsed_time} sec"):
|
|
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
|
|
assert len(out_shardings) == len(out_jaxpr_avals)
|
|
if spmd_lowering:
|
|
global_out_avals = out_jaxpr_avals
|
|
else:
|
|
# In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
|
|
# why `.spec` can be accessed.
|
|
global_out_avals = [untile_aval_nd(global_axis_sizes, _get_array_mapping(o.spec), aval) # type: ignore
|
|
for aval, o in safe_zip(out_jaxpr_avals, out_shardings)]
|
|
_sanitize_mesh_jaxpr(jaxpr)
|
|
if mesh.is_multi_process:
|
|
check_multihost_collective_allowlist(jaxpr)
|
|
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
|
|
|
# 2. Build up the HLO
|
|
tuple_args = len(in_jaxpr_avals) > 100 # pass long arg lists as tuple for TPU
|
|
in_partitions: Optional[List[Optional[xc.OpSharding]]]
|
|
out_partitions: Optional[List[Optional[xc.OpSharding]]]
|
|
axis_ctx: mlir.AxisContext
|
|
if spmd_lowering:
|
|
in_partitions = [
|
|
None if _is_auto(i) else i._to_xla_op_sharding(aval.ndim)
|
|
for aval, i in safe_zip(global_in_avals, in_shardings)
|
|
]
|
|
# TODO(yashkatariya): Fix the HLO produced if out_partitions is
|
|
# [None, OpShardingProto] has the sharding annotations.
|
|
out_partitions = [
|
|
None if _is_auto(o) or _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
|
|
for aval, o in safe_zip(global_out_avals, out_shardings)
|
|
]
|
|
replicated_args = [False] * len(in_jaxpr_avals)
|
|
axis_ctx = mlir.SPMDAxisContext(mesh)
|
|
axis_env = axis_ctx.axis_env
|
|
else:
|
|
replicated_args = [not _get_array_mapping(i.spec) for i in in_shardings] # type: ignore
|
|
in_partitions = None
|
|
out_partitions = None
|
|
axis_env = xla.AxisEnv(nreps=mesh.size,
|
|
names=tuple(global_axis_sizes.keys()),
|
|
sizes=tuple(global_axis_sizes.values()))
|
|
axis_ctx = mlir.ReplicaAxisContext(axis_env)
|
|
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
|
module: Union[str, xc.XlaComputation]
|
|
module_name = f"{api_name}_{fun_name}"
|
|
with core.extend_axis_env_nd(mesh.shape.items()):
|
|
if any(eff in core.ordered_effects for eff in closed_jaxpr.effects):
|
|
raise ValueError("Ordered effects not supported in mesh computations.")
|
|
unordered_effects = [eff for eff in closed_jaxpr.effects
|
|
if eff not in core.ordered_effects]
|
|
lowering_result = mlir.lower_jaxpr_to_module(
|
|
module_name, closed_jaxpr, unordered_effects, [], backend.platform,
|
|
axis_ctx, name_stack, donated_invars, replicated_args=replicated_args,
|
|
arg_shardings=in_partitions, result_shardings=out_partitions)
|
|
module, keepalive, host_callbacks = (
|
|
lowering_result.module, lowering_result.keepalive,
|
|
lowering_result.host_callbacks)
|
|
return MeshComputation(
|
|
str(name_stack),
|
|
module,
|
|
donated_invars,
|
|
mesh=mesh,
|
|
global_in_avals=global_in_avals,
|
|
global_out_avals=global_out_avals,
|
|
in_shardings=in_shardings,
|
|
out_shardings=out_shardings,
|
|
spmd_lowering=spmd_lowering,
|
|
tuple_args=tuple_args,
|
|
in_is_global=in_is_global,
|
|
auto_spmd_lowering=auto_spmd_lowering,
|
|
unordered_effects=unordered_effects,
|
|
host_callbacks=host_callbacks,
|
|
keepalive=keepalive)
|
|
|
|
|
|
class MeshComputation(stages.XlaLowering):
|
|
_hlo: Union[ir.Module, xc.XlaComputation]
|
|
_executable: Optional[MeshExecutable]
|
|
|
|
def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
|
|
donated_invars: Sequence[bool], **compile_args):
|
|
self._name = name
|
|
self._hlo = hlo
|
|
self._donated_invars = donated_invars
|
|
self.compile_args = compile_args
|
|
self._executable = None
|
|
|
|
# -- stages.XlaLowering overrides
|
|
|
|
def hlo(self) -> xc.XlaComputation:
|
|
# this is a method for api consistency with dispatch.XlaComputation
|
|
if isinstance(self._hlo, xc.XlaComputation):
|
|
return self._hlo
|
|
return xe.mlir.mlir_module_to_xla_computation(
|
|
mlir.module_to_string(self._hlo),
|
|
use_tuple_args=self.compile_args["tuple_args"])
|
|
|
|
def mhlo(self) -> ir.Module:
|
|
if isinstance(self._hlo, xc.XlaComputation):
|
|
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
|
with mlir.make_ir_context():
|
|
return ir.Module.parse(module_str)
|
|
return self._hlo
|
|
|
|
def compile(self,
|
|
_allow_propagation_to_outputs : bool = False,
|
|
_allow_compile_replicated : bool = True) -> MeshExecutable:
|
|
if self._executable is None:
|
|
self._executable = MeshExecutable.from_hlo(
|
|
self._name, self._hlo, **self.compile_args,
|
|
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
|
|
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
|
|
return self._executable
|
|
|
|
|
|
def _get_input_metadata(
|
|
global_in_avals: Sequence[ShapedArray],
|
|
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
|
|
) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
|
|
Sequence[ShapedArray]]:
|
|
from jax.experimental.sharding import MeshPspecSharding
|
|
|
|
shardings, input_indices, input_avals = [], [], []
|
|
for gaval, i, is_global in safe_zip(global_in_avals, in_shardings, in_is_global):
|
|
if is_global:
|
|
aval = gaval
|
|
sharding = i
|
|
else:
|
|
assert isinstance(i, MeshPspecSharding)
|
|
aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
|
|
sharding = MeshPspecSharding(i.mesh.local_mesh, i.spec)
|
|
|
|
# We special case this logic to support fully replicated values because
|
|
# the mesh is global mesh and the indices returned by `spec_to_indices` will
|
|
# represent index for each device in the global mesh. But here we want
|
|
# indices for the local devices of the global mesh.
|
|
proto = sharding._to_xla_op_sharding(aval.ndim)
|
|
if is_op_sharding_replicated(proto):
|
|
index = tuple((slice(None),) * aval.ndim for _ in range(len(sharding.addressable_devices)))
|
|
else:
|
|
index = tuple(sharding.devices_indices_map(aval.shape).values())
|
|
|
|
shardings.append(sharding)
|
|
input_indices.append(index)
|
|
input_avals.append(aval)
|
|
return shardings, input_indices, input_avals
|
|
|
|
|
|
def _get_op_sharding_shardings_from_executable(
|
|
xla_executable, device_assignment, num_in_avals, num_out_avals):
|
|
from jax.experimental import pjit
|
|
from jax.experimental.sharding import OpShardingSharding, SingleDeviceSharding
|
|
|
|
in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
|
|
|
|
# When the device assignment only has 1 device, SPMD partitioner will not run.
|
|
# Hence the op shardings will not be set on the `hlo_module`. In that case,
|
|
# just return SingleDeviceShardings since we know the computation is running
|
|
# only on 1 device.
|
|
if not in_op_shardings and not out_op_shardings and len(device_assignment) == 1:
|
|
return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
|
|
[SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
|
|
|
|
return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
|
|
[OpShardingSharding(device_assignment, o) for o in out_op_shardings])
|
|
|
|
|
|
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
|
|
# without mesh.
|
|
def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
|
|
from jax.experimental import pjit
|
|
from jax.experimental.sharding import MeshPspecSharding
|
|
|
|
in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
|
|
return ([MeshPspecSharding(mesh, i) for i in in_pspec],
|
|
[MeshPspecSharding(mesh, o) for o in out_pspec])
|
|
|
|
|
|
class MeshExecutable(stages.XlaExecutable):
|
|
__slots__ = ['xla_executable', 'unsafe_call', '_input_avals',
|
|
'_in_shardings', '_out_shardings', '_auto_spmd_lowering']
|
|
|
|
def __init__(self, xla_executable, unsafe_call, input_avals,
|
|
in_shardings, out_shardings, auto_spmd_lowering):
|
|
self.xla_executable = xla_executable
|
|
self.unsafe_call = unsafe_call
|
|
# input_avals is a list of global and local avals. Aval is global if input
|
|
# is a GDA else local.
|
|
self._input_avals = input_avals
|
|
self._in_shardings = in_shardings
|
|
self._out_shardings = out_shardings
|
|
self._auto_spmd_lowering = auto_spmd_lowering
|
|
|
|
@staticmethod
|
|
def from_hlo(name: str,
|
|
computation: Union[ir.Module, xc.XlaComputation],
|
|
# TODO(yashkatariya): Remove `mesh` from here once AUTO can work
|
|
# without mesh.
|
|
mesh: Optional[Mesh],
|
|
global_in_avals: Sequence[ShapedArray],
|
|
global_out_avals: Sequence[ShapedArray],
|
|
in_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource]],
|
|
out_shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource,
|
|
_UnspecifiedValue]],
|
|
spmd_lowering: bool,
|
|
tuple_args: bool,
|
|
in_is_global: Sequence[bool],
|
|
auto_spmd_lowering: bool,
|
|
_allow_propagation_to_outputs: bool,
|
|
_allow_compile_replicated: bool,
|
|
unordered_effects: List[core.Effect],
|
|
host_callbacks: List[Any],
|
|
keepalive: Any) -> MeshExecutable:
|
|
if auto_spmd_lowering:
|
|
assert mesh is not None
|
|
assert not mesh.empty
|
|
backend = xb.get_device_backend(mesh.devices.flat[0])
|
|
else:
|
|
backend, first_sharding = _get_backend_from_shardings(
|
|
it.chain(in_shardings, out_shardings)) # type: ignore
|
|
|
|
dev: np.ndarray
|
|
if auto_spmd_lowering:
|
|
assert mesh is not None and spmd_lowering
|
|
dev = mesh.devices
|
|
num_replicas, num_partitions = 1, mesh.size
|
|
else:
|
|
dev = np.array(first_sharding._device_assignment)
|
|
if spmd_lowering:
|
|
num_replicas, num_partitions = 1, dev.size
|
|
else:
|
|
num_replicas, num_partitions = dev.size, 1
|
|
device_assignment = dev.reshape((num_replicas, num_partitions))
|
|
|
|
compile_options = xb.get_compile_options(
|
|
num_replicas=num_replicas,
|
|
num_partitions=num_partitions,
|
|
device_assignment=device_assignment,
|
|
use_spmd_partitioning=spmd_lowering,
|
|
use_auto_spmd_partitioning=auto_spmd_lowering,
|
|
)
|
|
if auto_spmd_lowering:
|
|
assert mesh is not None
|
|
compile_options.executable_build_options.auto_spmd_partitioning_mesh_shape = \
|
|
list(mesh.shape.values())
|
|
compile_options.executable_build_options.auto_spmd_partitioning_mesh_ids = \
|
|
_get_logical_mesh_ids(list(mesh.shape.values())).reshape(-1)
|
|
compile_options.parameter_is_tupled_arguments = tuple_args
|
|
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
|
|
_allow_propagation_to_outputs
|
|
|
|
if _allow_compile_replicated and hasattr(backend, "compile_replicated"):
|
|
assert not auto_spmd_lowering
|
|
in_shardings, input_indices, input_avals = _get_input_metadata(
|
|
global_in_avals, in_shardings, in_is_global) # type: ignore
|
|
handle_outs = global_avals_to_results_handler(
|
|
global_out_avals, out_shardings) # type: ignore # arg-type
|
|
unsafe_call = backend.compile_replicated(
|
|
computation, compile_options, input_avals, input_indices,
|
|
in_shardings, handle_outs)
|
|
xla_executable = None
|
|
else:
|
|
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
|
|
"in {elapsed_time} sec"):
|
|
xla_executable = dispatch.compile_or_get_cached(
|
|
backend, computation, compile_options, host_callbacks)
|
|
|
|
if auto_spmd_lowering:
|
|
assert mesh is not None
|
|
in_shardings, out_shardings = _get_mesh_pspec_shardings_from_executable(
|
|
xla_executable, mesh)
|
|
elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
|
|
assert mesh is None
|
|
in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
|
|
xla_executable, first_sharding._device_assignment,
|
|
len(global_in_avals), len(global_out_avals))
|
|
|
|
in_shardings, input_indices, input_avals = _get_input_metadata(
|
|
global_in_avals, in_shardings, in_is_global) # type: ignore
|
|
handle_outs = global_avals_to_results_handler(
|
|
global_out_avals, out_shardings) # type: ignore # arg-type
|
|
handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
|
|
input_indices)
|
|
unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args,
|
|
handle_outs, unordered_effects, keepalive)
|
|
|
|
return MeshExecutable(xla_executable, unsafe_call, input_avals,
|
|
in_shardings, out_shardings, auto_spmd_lowering)
|
|
|
|
# -- stages.XlaExecutable overrides
|
|
|
|
def xla_extension_executable(self):
|
|
return self.xla_executable
|
|
|
|
def call(self, *args):
|
|
arg_avals = map(xla.abstractify, args)
|
|
ref_avals = self._input_avals
|
|
dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
|
|
# Check the GDA sharding and the input sharding.
|
|
_check_gda_or_array_xla_sharding_match(args, self._in_shardings)
|
|
return self.unsafe_call(*args)
|
|
|
|
|
|
@lru_cache()
|
|
def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
|
|
from jax.experimental.sharding import MeshPspecSharding
|
|
return MeshPspecSharding(mesh, pspec, parsed_pspec)
|
|
|
|
|
|
def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
|
|
from jax.experimental.global_device_array import GlobalDeviceArray
|
|
from jax.experimental.array import Array
|
|
|
|
@lru_cache(maxsize=4096)
|
|
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim):
|
|
if not are_op_shardings_equal(
|
|
arg_sharding._to_xla_op_sharding(ndim),
|
|
in_xla_sharding._to_xla_op_sharding(ndim)):
|
|
raise ValueError(
|
|
f"{arg_type} sharding does not match the input sharding. "
|
|
f"Got {arg_type} sharding: {arg_sharding} and "
|
|
f"xla sharding: {in_xla_sharding}")
|
|
|
|
for arg, xs in safe_zip(args, in_xla_shardings):
|
|
if not isinstance(arg, (GlobalDeviceArray, Array)):
|
|
continue
|
|
if isinstance(arg, GlobalDeviceArray):
|
|
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,
|
|
'GDA', arg.ndim)
|
|
else:
|
|
_cached_check(arg.sharding, xs, 'Array', arg.ndim)
|
|
|
|
|
|
def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
|
# Import here to avoid cyclic import error when importing gda in pjit.py.
|
|
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
|
|
|
|
parsed_pspec, _, _, _ = _prepare_axis_resources(pspec, "pspec to array_mapping")
|
|
return get_array_mapping(parsed_pspec)
|
|
|
|
|
|
def are_op_shardings_equal(op1, op2) -> bool:
|
|
if id(op1) == id(op2):
|
|
return True
|
|
if xla_extension_version >= 81:
|
|
if is_op_sharding_replicated(op1) and is_op_sharding_replicated(op2):
|
|
return True
|
|
return xc.HloSharding.from_proto(op1) == xc.HloSharding.from_proto(op2)
|
|
else:
|
|
if op1.type == xc.OpSharding.Type.TUPLE:
|
|
return all(are_op_shardings_equal(i, j)
|
|
for i, j in safe_zip(op1.tuple_shardings, op2.tuple_shardings))
|
|
return (op1.type == op2.type and
|
|
op1.tile_assignment_dimensions == op2.tile_assignment_dimensions and
|
|
op1.tile_assignment_devices == op2.tile_assignment_devices and
|
|
op1.last_tile_dims == op2.last_tile_dims and
|
|
op1.replicate_on_last_tile_dim == op2.replicate_on_last_tile_dim)
|
|
|
|
|
|
def is_op_sharding_replicated(op: xc.OpSharding) -> bool:
|
|
if xla_extension_version >= 82:
|
|
if len(op.tile_assignment_devices) == 1:
|
|
return True
|
|
return xc.HloSharding.from_proto(op).is_replicated()
|
|
else:
|
|
return op.type == xc.OpSharding.Type.REPLICATED
|
|
|
|
|
|
_forbidden_primitives = {
|
|
'xla_pmap': 'pmap',
|
|
'sharded_call': 'sharded_jit',
|
|
}
|
|
def _sanitize_mesh_jaxpr(jaxpr):
|
|
if isinstance(jaxpr, core.ClosedJaxpr):
|
|
jaxpr = jaxpr.jaxpr
|
|
for eqn in jaxpr.eqns:
|
|
if eqn.primitive.name in _forbidden_primitives:
|
|
raise RuntimeError(f"Nesting {_forbidden_primitives[eqn.primitive.name]} "
|
|
f"inside xmaps not supported!")
|
|
core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)
|
|
|
|
|
|
custom_resource_typing_rules: Dict[core.Primitive, Callable] = {}
|
|
|
|
def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
|
|
if isinstance(jaxpr, core.ClosedJaxpr):
|
|
jaxpr = jaxpr.jaxpr
|
|
def _check_aval(aval, what_thunk):
|
|
if not hasattr(aval, 'named_shape'):
|
|
return
|
|
resource_to_axis = {}
|
|
for axis in aval.named_shape:
|
|
for resource in axis_resources[axis]:
|
|
if resource in resource_to_axis:
|
|
other_axis = resource_to_axis[resource]
|
|
axis, other_axis = sorted([str(axis), str(other_axis)])
|
|
raise JAXTypeError(
|
|
f"Axes `{axis}` and `{other_axis}` are both mapped to the "
|
|
f"resource `{resource}`, but they coincide in the named_shape "
|
|
f"of {what_thunk()}")
|
|
resource_to_axis[resource] = axis
|
|
|
|
what_thunk = lambda: (f"an input to {what_jaxpr_thunk()}")
|
|
for v in jaxpr.constvars:
|
|
_check_aval(v.aval, what_thunk)
|
|
for v in jaxpr.invars:
|
|
_check_aval(v.aval, what_thunk)
|
|
what_thunk = lambda: (f"a value returned from a primitive {eqn.primitive} created "
|
|
f"at {source_info_util.summarize(eqn.source_info)}")
|
|
rec_what_jaxpr_thunk = lambda: (f"a primitive {eqn.primitive} created at"
|
|
f"{source_info_util.summarize(eqn.source_info)}")
|
|
for eqn in jaxpr.eqns:
|
|
typing_rule = custom_resource_typing_rules.get(eqn.primitive, None)
|
|
if typing_rule:
|
|
typing_rule([v.aval for v in eqn.invars], eqn.params, eqn.source_info,
|
|
resource_env, axis_resources)
|
|
else:
|
|
core.traverse_jaxpr_params(partial(resource_typecheck,
|
|
resource_env=resource_env,
|
|
axis_resources=axis_resources,
|
|
what_jaxpr_thunk=rec_what_jaxpr_thunk),
|
|
eqn.params)
|
|
for v in eqn.outvars:
|
|
_check_aval(v.aval, what_thunk)
|
|
|
|
|
|
def _make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
|
|
mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
|
|
sharding = [_UNSHARDED_INSTANCE] * num_dimensions
|
|
next_sharded_axis = 0
|
|
# NOTE: sorted is stable, which is important when multiple resources
|
|
# map to the same axis.
|
|
for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
|
|
chunked = sharding[axis]
|
|
if isinstance(chunked, NoSharding):
|
|
chunked = Chunked([])
|
|
sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
|
|
assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
|
|
"Value mapped to the same mesh axis twice"
|
|
mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
|
|
next_sharded_axis += 1
|
|
return ShardingSpec(sharding, mesh_mapping)
|
|
|
|
|
|
def new_mesh_sharding_specs(axis_sizes, axis_names):
|
|
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
|
|
return partial(_make_sharding_spec, axis_sizes, mesh_axis_pos)
|
|
|
|
|
|
def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
|
|
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
|
|
# NOTE: This takes in the non-sharded avals!
|
|
def mk_sharding_spec(aval, aval_axes):
|
|
if aval is core.abstract_token:
|
|
assert not aval_axes
|
|
return ShardingSpec([], [Replicated(axis_size) for axis_size in axis_sizes.values()])
|
|
aval_shape = list(aval.shape)
|
|
# NOTE: sorted is stable, which is important when multiple resources
|
|
# map to the same axis.
|
|
for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
|
|
if not allow_uneven_axes:
|
|
if aval_shape[axis] % axis_sizes[name] != 0:
|
|
raise ValueError(
|
|
f'The aval shape on dimension {axis} is {aval_shape[axis]} and '
|
|
f'the size of axis {name} is {axis_sizes[name]}. The aval shape % '
|
|
'axis size should be zero but got '
|
|
f'{aval_shape[axis] % axis_sizes[name]}')
|
|
aval_shape[axis] //= axis_sizes[name]
|
|
return _make_sharding_spec(axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
|
|
return mk_sharding_spec
|
|
|
|
|
|
@contextmanager
|
|
def maybe_extend_axis_env(*args, **kwargs):
|
|
with core.extend_axis_env(*args, **kwargs):
|
|
yield
|
|
|
|
class DynamicAxisEnvFrame:
|
|
__slots__ = ["name", "pmap_trace", "hard_size"]
|
|
def __init__(self, name, pmap_trace, hard_size):
|
|
self.name = name
|
|
self.pmap_trace = pmap_trace
|
|
self.hard_size = hard_size
|
|
|
|
class DynamicAxisEnv(list):
|
|
def __contains__(self, axis_name):
|
|
return axis_name in (frame.name for frame in self)
|
|
|
|
def __getitem__(self, axis_name):
|
|
if axis_name not in self:
|
|
raise NameError(f"unbound axis name: {axis_name}")
|
|
for frame in reversed(self):
|
|
if frame.name == axis_name:
|
|
return frame
|
|
|
|
raise AssertionError
|
|
|
|
@property
|
|
def sizes(self):
|
|
return tuple(frame.hard_size for frame in self)
|
|
|
|
@property
|
|
def nreps(self):
|
|
return prod(frame.hard_size for frame in self)
|
|
|
|
class _ThreadLocalState(threading.local):
|
|
def __init__(self):
|
|
self.dynamic_axis_env = DynamicAxisEnv()
|
|
|
|
_thread_local_state = _ThreadLocalState()
|
|
|
|
def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool=False) -> List[xb.xla_client.Buffer]:
|
|
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
|
|
if replicate:
|
|
return list(it.chain.from_iterable(dispatch.device_put(x, device) for device in devices))
|
|
else:
|
|
return list(it.chain.from_iterable(dispatch.device_put(val, device) for val, device in safe_zip(x, devices)))
|
|
|
|
def _set_aval(val):
|
|
if val.aval is None:
|
|
val.aval = core.ShapedArray(val.shape, val.dtype)
|
|
return val
|