# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of pmap and related functionality."""

from __future__ import annotations

import enum
from contextlib import contextmanager
from collections import defaultdict, namedtuple
import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union,
                    Iterable, cast, TypeVar)

import numpy as np

import jax
from jax.errors import JAXTypeError
from jax.tree_util import tree_map

from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
    ArrayMapping, ArrayMappingOrAutoOrUnspecified,
    AUTO, UnspecifiedValue, UNSPECIFIED,
    get_array_mapping as _get_array_mapping, is_auto, is_unspecified
)
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
                           wrap_name, tuple_delete, distributed_debug_log,
                           unzip2, HashableFunction, weakref_lru_cache)


# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
  pass
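
# For example, `weakref.ref([])` raises TypeError, whereas
# `weakref.ref(WeakRefList())` works; this is what lets the cached results of
# parallel_callable (below) be tracked through weak references.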

xe = xc._xla

unsafe_map, map = map, safe_map  # type: ignore

logger = logging.getLogger(__name__)

Index = Union[int, slice, tuple[Union[int, slice], ...]]

NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
Unstacked = sharding_specs.Unstacked

ShardedAxis = sharding_specs.ShardedAxis
Replicated = sharding_specs.Replicated

AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = mesh_lib.Mesh
MeshAxisName = sharding_impls.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec


### util

def identity(x): return x

def shard_arg(arg, devices, arg_indices, sharding):
  """Returns a list of size len(devices) containing per-device buffers.

  For the C++ pmap path, we fall back to Python (this function) to shard
  arguments that are not supported by the C++ `ShardArg`.

  Args:
    arg: The Python argument.
    devices: The list of devices to shard over.
    arg_indices: A list of `len(devices)` indices to use to shard the argument.
    sharding: The sharding to use for the resulting per-device values.
  """
  arg = xla.canonicalize_dtype(arg)
  return shard_arg_handlers[type(arg)](arg, devices, arg_indices, sharding)


@profiler.annotate_function
def shard_args(
    devices: Sequence[xb.xla_client.Device],
    indices: Sequence[Sequence[Index]],
    shardings: Sequence[sharding_impls.XLACompatibleSharding],
    args,
) -> Sequence[jax.Array]:
  """Shard each argument data array along its leading axis.

  Args:
    devices: sequence of Devices mapping replica index to a physical device.
    indices: sequence of the same length as `args` describing how each arg
      should be sharded/replicated across `devices`. Each element in `indices`
      is the same length as `devices`.
    shardings: sequence of shardings, one per argument.
    args: a sequence of JaxTypes representing arguments to be sharded according
      to `indices` and placed on `devices`.

  Returns:
    A list of length matching args, containing lists of per-device buffers
    for each argument.
  """
  return [shard_arg(arg, devices, indices[i], shardings[i])
          for i, arg in enumerate(args)]
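
# A minimal usage sketch (illustrative only; assumes two addressable devices,
# and builds a PmapSharding the same way make_sharded_device_array does below):
#
#   devices = jax.local_devices()[:2]
#   spec = sharding_specs.create_pmap_sharding_spec((2, 3))
#   sharding = sharding_impls.PmapSharding(np.array(devices), spec)
#   indices = sharding_specs.spec_to_indices((2, 3), spec)
#   x = np.arange(6, dtype=np.float32).reshape(2, 3)
#   [arr] = shard_args(devices, [indices], [sharding], [x])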

shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}


def _shard_token(x, devices, indices, sharding):
  zeros = np.zeros((), dtype=np.dtype(np.bool_))
  aval = api_util.shaped_abstractify(zeros)
  out = batched_device_put(aval, sharding, [zeros for i in indices], devices)
  return out
shard_arg_handlers[core.Token] = _shard_token


def _masked_array_error(x, devices, indices, sharding):
  raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                   "Use arr.filled() to convert the value to a standard numpy array.")
shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error


def _shard_array(x, devices, indices, sharding):
  if x.dtype == dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  aval = api_util.shaped_abstractify(x)
  out = batched_device_put(aval, sharding, [x[i] for i in indices], devices)
  return out
for _t in array_types:
  shard_arg_handlers[_t] = _shard_array

def shard_device_array(x, devices, indices, sharding):
  start_indices, limit_indices, removed_dims = unzip3(
      as_slice_indices(x, idx) for idx in indices)
  shards = x._multi_slice(start_indices, limit_indices, removed_dims)
  aval = api_util.shaped_abstractify(x)
  out = batched_device_put(aval, sharding, shards, devices)
  return out


def batched_device_put(aval: core.ShapedArray,
                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
                       devices: Sequence[jax.Device], committed: bool = True):
  from jax._src import array

  bufs = [x for x, d in safe_zip(xs, devices)
          if (isinstance(x, array.ArrayImpl) and
              dispatch.is_single_device_sharding(x.sharding) and
              x.device() == d)]
  if len(bufs) == len(xs):
    return array.ArrayImpl(
        aval, sharding, bufs, committed=committed, _skip_checks=True)
  return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
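
# A minimal usage sketch (illustrative, assuming at least one addressable
# device): a single numpy shard placed with a SingleDeviceSharding comes back
# as one committed jax.Array.
#
#   dev = jax.devices()[0]
#   aval = core.ShapedArray((2,), np.dtype('float32'))
#   out = batched_device_put(aval, jax.sharding.SingleDeviceSharding(dev),
#                            [np.zeros(2, np.float32)], [dev])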

# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def as_slice_indices(arr: Any, idx: Index) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
  """Returns start_indices, limit_indices, removed_dims"""
  start_indices = [0] * arr.ndim
  limit_indices = list(arr.shape)
  removed_dims = []

  tuple_idx = idx if isinstance(idx, tuple) else (idx,)
  for dim, sub_idx in enumerate(tuple_idx):
    if isinstance(sub_idx, int):
      start_indices[dim] = sub_idx
      limit_indices[dim] = sub_idx + 1
      removed_dims.append(dim)
    elif sub_idx == slice(None):
      continue
    else:
      assert isinstance(sub_idx, slice), sub_idx
      assert isinstance(sub_idx.start, int), sub_idx
      assert isinstance(sub_idx.stop, int), sub_idx
      start_indices[dim] = sub_idx.start
      limit_indices[dim] = sub_idx.stop

  return tuple(start_indices), tuple(limit_indices), tuple(removed_dims)  # type: ignore
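
# For example, for an `arr` of shape (4, 8):
#   as_slice_indices(arr, (slice(0, 2), 3)) == ((0, 3), (2, 4), (1,))
# i.e. rows [0, 2), column 3, with dimension 1 removed from each shard.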

def shard_aval(size, axis: int, aval):
  try:
    return shard_aval_handlers[type(aval)](size, axis, aval)
  except KeyError as err:
    raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}

def _shard_abstract_array(size, axis: int, x):
  try:
    if x.shape[axis] != size:
      raise ValueError(f"Axis size {size} does not match dimension {axis} of "
                       f"shape {x.shape}")
  except IndexError:
    raise ValueError(f"Cannot split a {x.ndim}D value along axis {axis}") from None
  return x.update(shape=tuple_delete(x.shape, axis))
shard_aval_handlers[ShapedArray] = _shard_abstract_array
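
# For example, sharding a (4, 3) aval along its leading axis of size 4 drops
# that axis from the per-shard aval:
#   shard_aval(4, 0, ShapedArray((4, 3), np.dtype('float32')))
#   == ShapedArray((3,), np.dtype('float32'))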

def local_aval_to_result_handler(
    aval: core.AbstractValue,
    sharding: sharding_impls.XLACompatibleSharding,
    indices: Optional[tuple[Index, ...]],
) -> Callable[[list[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  Args:
    aval: The local output AbstractValue.
    sharding: Indicates how the output is sharded across devices, or None
      for non-array avals.
    indices: The pre-computed result of spec_to_indices, or None for non-array
      avals.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user, e.g. a ShardedDeviceArray.
  """
  try:
    return local_result_handlers[(type(aval))](aval, sharding, indices)
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {type(aval)}") from err

PxlaResultHandler = Callable[..., Callable[[Any], Any]]
local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}

def global_aval_to_result_handler(
    aval: core.AbstractValue, out_sharding, committed: bool,
    is_out_sharding_from_xla: bool
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  Args:
    aval: The global output AbstractValue.
    out_sharding: The sharding of the output, used when creating the resulting
      global arrays.
    committed: Whether the resulting arrays should be committed to their
      devices.
    is_out_sharding_from_xla: True if the out_sharding comes from XLA, i.e.
      the sharding is extracted from the HLO.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user, e.g. a ShardedDeviceArray.
  """
  try:
    return global_result_handlers[type(aval)](
        aval, out_sharding, committed, is_out_sharding_from_xla)
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {type(aval)}") from err

global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}

### lazy device-memory persistence and result handling

# TODO(yashkatariya, phawkins): Remove this function after March 15, 2023.
def make_sharded_device_array(
    aval: ShapedArray,
    sharding_spec: Optional[ShardingSpec],
    # Any is for JAX extensions implementing their own buffer.
    device_buffers: list[Any],
    indices: Optional[tuple[Index, ...]] = None,
):
  """Returns a ShardedDeviceArray implementation based on arguments.

  Returns either a C++ SDA or a Python DeviceArray when the buffers are not
  JAX buffers.

  Args:
    aval: The `ShapedArray` for this array.
    sharding_spec: If `None`, assumes a pmap-style ShardedDeviceArrays over the
      first dimension.
    device_buffers: If a list of Jax `Buffer` objects, a C++ SDA will be
      returned (if the version is high enough). Otherwise, a Python object will
      be returned, for JAX extensions not implementing the C++ API.
    indices: For caching purposes, will be computed if `None`.
  """
  if sharding_spec is None:
    sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)

  mesh = mesh_lib.thread_resources.env.physical_mesh
  sharding: sharding_impls.XLACompatibleSharding
  if mesh.empty:
    sharding = sharding_impls.PmapSharding(
        np.asarray([d.device() for d in device_buffers]), sharding_spec)
  else:
    hlo_sharding = sharding_specs.sharding_spec_sharding_proto(sharding_spec)
    pspec = sharding_impls.parse_flatten_op_sharding(
        hlo_sharding, mesh)[0].get_partition_spec()
    sharding = sharding_impls.NamedSharding(mesh, pspec)

  return jax.make_array_from_single_device_arrays(
      aval.shape, sharding, device_buffers)  # type: ignore

def _hashable_index(idx):
  return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
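
# For example, _hashable_index((slice(0, 2), 3)) == ((0, 2), 3), which makes a
# per-device index usable as a dict key below.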

# The fast path is handled directly in shard_args().
# TODO(yashkatariya): Move this to array.py when SDA is deleted. The local
# import of Array should go away at that time.
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
  from jax._src.array import ArrayImpl

  candidates = defaultdict(list)
  if isinstance(x, ArrayImpl):
    bufs = [buf.data for buf in x.addressable_shards]
    arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
  else:
    bufs = x.device_buffers
    arr_indices = x.indices
  for buf, idx in safe_zip(bufs, arr_indices):
    candidates[_hashable_index(idx)].append(buf)

  bufs = []
  for idx, device in safe_zip(indices, devices):
    # Look up all buffers that contain the correct slice of the logical array.
    candidates_list = candidates[_hashable_index(idx)]
    if not candidates_list:
      # This array isn't sharded correctly. Reshard it via host roundtrip.
      # TODO(skye): more efficient reshard?
      return shard_arg_handlers[type(x._value)](
          x._value, devices, indices, sharding)
    # Try to find a candidate buffer already on the correct device,
    # otherwise copy one of them.
    for buf in candidates_list:
      if buf.device() == device:
        bufs.append(buf)
        break
    else:
      bufs.append(buf)

  return batched_device_put(x.aval, sharding, bufs, devices)

### the xla_pmap primitive and its rules are comparable to xla_call in xla.py

def xla_pmap_impl_lazy(
    fun: lu.WrappedFun,
    *args,
    backend: Optional[str],
    axis_name: core.AxisName,
    axis_size: int,
    global_axis_size: int,
    devices: Optional[Sequence[Any]],
    name: str,
    in_axes: Sequence[Optional[int]],
    out_axes_thunk: Callable[[], Sequence[Optional[int]]],
    donated_invars: Sequence[bool],
    is_explicit_global_axis_size: bool,
) -> Callable:
  if (config.jax_disable_jit and config.jax_eager_pmap and
      not is_explicit_global_axis_size and not any(d for d in donated_invars)):
    def _emap_apply_fn(*args):
      return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
                        axis_size=axis_size, global_axis_size=global_axis_size,
                        devices=devices, name=name, in_axes=in_axes,
                        out_axes_thunk=out_axes_thunk,
                        donated_invars=donated_invars,
                        is_explicit_global_axis_size=is_explicit_global_axis_size)
    return _emap_apply_fn
  abstract_args = unsafe_map(xla.abstractify, args)
  compiled_fun, fingerprint = parallel_callable(
      fun, backend, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, *abstract_args)

  # Don't re-abstractify args unless logging is enabled for performance.
  if config.jax_distributed_debug:
    distributed_debug_log(("Running pmapped function", name),
                          ("python function", fun.f),
                          ("devices", devices),
                          ("abstract args", map(xla.abstractify, args)),
                          ("fingerprint", fingerprint))
  return compiled_fun


def xla_pmap_impl(fun: lu.WrappedFun, *args, **params):
  compiled_fun = xla_pmap_impl_lazy(fun, *args, **params)
  return compiled_fun(*args)
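
# A small sketch of the eager path above (assuming the `jax_eager_pmap` config
# flag is exposed under this name): with jit disabled and eager pmap enabled,
# xla_pmap_impl_lazy routes through _emap_impl instead of compiling.
#
#   jax.config.update('jax_eager_pmap', True)
#   with jax.disable_jit():
#     jax.pmap(lambda x: x * 2)(np.arange(jax.local_device_count()))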

class EmapInfo(NamedTuple):
  backend: Optional[str]
  devices: Optional[Sequence[Any]]


def _emap_impl(fun: lu.WrappedFun, *args,
               backend: Optional[str],
               axis_name: core.AxisName,
               axis_size: int,
               global_axis_size: int,
               devices: Optional[Sequence[Any]],
               name: str,
               in_axes: Sequence[Optional[int]],
               out_axes_thunk: Callable[[], Sequence[Optional[int]]],
               donated_invars: Sequence[bool],
               is_explicit_global_axis_size: bool,
               ):
  from jax._src import array
  # TODO(sharadmv,mattjj): implement these cases
  if any(d for d in donated_invars):
    raise NotImplementedError("Buffer donation not supported in eager pmap.")
  if is_explicit_global_axis_size:
    raise NotImplementedError("Non-default global_axis_size not supported in "
                              "eager pmap.")

  emap_info = EmapInfo(backend, devices)
  shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
  with core.new_base_main(MapTrace, emap_info=emap_info) as main:
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
      t = main.with_cur_sublevel()
      tracers = [
          MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
      ans = fun.call_wrapped(*tracers)
      out_tracers = map(t.full_raise, ans)
      outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
    del main
  out_axes = out_axes_thunk()

  platform = xb.get_backend(backend).platform
  donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else ()
  new_outvals = []
  for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
    with jax.disable_jit(False):
      donate_argnums_ = donate_argnums
      if isinstance(outval, array.ArrayImpl):
        # We don't want to donate if it's already sharded.
        donate_argnums_ = ()
      out = jax.pmap(
          lambda _, x: x,
          in_axes=(0, out_axis_src.get(axis_name)),
          out_axes=out_axis,
          devices=(None if devices is None else list(devices)),
          backend=backend,
          donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
      new_outvals.append(out)
  return new_outvals

def _map_schedule(idx: tuple[Optional[int], ...]) -> tuple[Optional[int], ...]:
  # In order to do a multi-map (a simultaneous map over several axes), we will
  # nest several maps. Each time we do a map, we "remove" an input axis so we
  # need to update the remaining map axes. For example, if we are to map over
  # the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
  return tuple(None if i is None else
               i - sum(j is not None and j < i for j in idx[:l])
               for l, i in enumerate(idx))
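
# For example, mapping over axes 0, 3, and 4 of a single argument:
#   _map_schedule((0, 3, 4)) == (0, 2, 2)
# matching the nested-pmap schedule described in the comment above.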

# We're often creating `f`s on the fly and we try to carefully make them have
# the right __hash__ and __eq__. However, despite our attempts pmap's caching
# still ends up not working, because it has a separate cache per
# _function object_. Adding this annotation here lets us reuse the same pmap
# callable for all equivalent primitive pmaps.
@lru_cache()
def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName],
                all_axes: list[tuple[Optional[int], ...]]
                ) -> tuple[Callable, dict[core.AxisName, int]]:
  used_names = []
  for i, name in reversed(list(enumerate(names))):
    in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
    if any(in_axis is not None for in_axis in in_axes):
      f = jax.pmap(
          f,
          in_axes=in_axes,
          axis_name=name,
          out_axes=0,
          backend=info.backend,
          devices=(None if info.devices is None else list(info.devices)))
      used_names.append(name)
  out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
  return f, out_shard_axes

FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"])


class MapTrace(core.Trace):

  def __init__(self, *args, emap_info):
    super().__init__(*args)
    self.emap_info = emap_info

  def pure(self, val):
    return MapTracer(self, val, {})

  def sublift(self, tracer):
    return MapTracer(self, tracer.val, tracer.shard_axes)

  def process_primitive(self, primitive, tracers, params):
    info = self.main.payload["emap_info"]
    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
    names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env
                  if f.main_trace is self.main)
    all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
    f = HashableFunction(lambda *args: primitive.bind(*args, **params),
                         (primitive, tuple(params.items())))
    f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes)
    with core.eval_context(), jax.disable_jit(False):
      outvals = f_mapped(*vals)
    if primitive.multiple_results:
      return [MapTracer(self, val, out_shard_axes) for val in outvals]
    return MapTracer(self, outvals, out_shard_axes)

  def process_call(self, call_primitive, fun, tracers, params):
    raise NotImplementedError

  def process_map(self, map_primitive, fun, tracers, params):
    if params['devices'] is not None:
      raise ValueError("Nested pmap with explicit devices argument.")
    if not config.jax_disable_jit:
      bind = HashableFunction(
          lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
          (map_primitive, fun))
      fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
      return self.process_primitive(fake_primitive, tracers, params)
    axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
        params["in_axes"], params["out_axes_thunk"], params["axis_size"])
    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
    shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
                  if ax is not None else s
                  for v, ax, s in zip(vals, in_axes, shard_axes)]
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
      t = self.main.with_cur_sublevel()
      in_tracers = map(partial(MapTracer, t), vals, shard_axes)
      ans = fun.call_wrapped(*in_tracers)
      out_tracers = map(t.full_raise, ans)
      out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
      del t, in_tracers, ans, out_tracers
    out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
                          for v, s, dst in zip(out, outaxes, out_axes_thunk()))
    return map(partial(MapTracer, self), out, outaxes)

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    bind = HashableFunction(
        lambda *args, **kwargs: prim.bind(
            fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
        (prim, fun, jvp, symbolic_zeros))
    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
    return self.process_primitive(fake_primitive, tracers, {})

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                              out_trees, symbolic_zeros):
    bind = HashableFunction(
        lambda *args, **kwargs: primitive.bind(
            fun, fwd, bwd, *args, out_trees=out_trees,
            symbolic_zeros=symbolic_zeros, **kwargs),
        (primitive, fun, fwd, bwd))
    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
    return self.process_primitive(fake_primitive, tracers, {})

  def process_axis_index(self, frame):
    bind = HashableFunction(
        lambda _: jax.lax.axis_index(frame.name),
        (jax.lax.axis_index, frame.name))
    fake_primitive = FakePrimitive(multiple_results=False, bind=bind)
    with core.eval_context():
      range = jax.lax.iota(np.int32, frame.size)
    dummy_tracer = MapTracer(self, range, {frame.name: 0})
    return self.process_primitive(fake_primitive, (dummy_tracer,), {})

def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
                   annotation: Optional[int]) -> Optional[int]:
  if annotation is None: return None
  mapped_axes_ = set(mapped_axes)
  return [i for i in range(ndim) if i not in mapped_axes_][annotation]
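
# For example, with ndim=3 and axis 0 already mapped, annotation 1 refers to
# the second remaining axis: _annot_to_flat(3, [0], 1) == 2.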

def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
                 shard_axis_src: dict[core.AxisName, int],
                 dst_annotation: Optional[int]
                 ) -> tuple[Any, dict[core.AxisName, int]]:
  shard_axis_out = dict(shard_axis_src)
  src = shard_axis_out.pop(axis_name, None)
  dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
                       dst_annotation)
  with core.eval_context():
    if src == dst:
      outval = val
    elif type(src) == type(dst) == int:
      outval = batching.moveaxis(val, src, dst)
      shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst)
    elif src is None and dst is not None:
      outval = batching.broadcast(val, axis_size, dst)
      shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
    else:
      raise NotImplementedError
  return outval, shard_axis_out

def _moveaxis(ndim: int, shard_axes: dict[core.AxisName, int],
              src: int, dst: int) -> dict[core.AxisName, int]:
  lst: list[Optional[core.AxisName]] = [None] * ndim
  for k, v in shard_axes.items():
    lst[v] = k
  name = lst.pop(src)
  lst.insert(dst - (src < dst), name)
  return {name: i for i, name in enumerate(lst) if name is not None}
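
# For example, moving axis 0 to axis 2 in a rank-3 value shifts the named axes
# accordingly: _moveaxis(3, {'i': 0, 'j': 2}, 0, 2) == {'i': 1, 'j': 2}.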

class MapTracer(core.Tracer):
  __slots__ = ["val", "shard_axes"]

  def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):
    self._trace = trace
    self.val = val
    self.shard_axes = shard_axes
    assert all(val < self.val.ndim for val in self.shard_axes.values())

  @property
  def aval(self):
    aval = xla.abstractify(self.val)
    shard_axes = dict(self.shard_axes)
    for axis_idx in sorted(shard_axes.values())[::-1]:
      aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
    return aval

  def full_lower(self):
    return self

  def __str__(self):
    named_axes = [f"{k}={v}" for k, v in self.shard_axes.items()]
    return f"{self.val}{{{','.join(named_axes)}}}"

@lu.cache
def parallel_callable(fun: lu.WrappedFun,
                      backend_name: Optional[str],
                      axis_name: core.AxisName,
                      axis_size: int,
                      global_axis_size: int,
                      devices: Optional[Sequence[Any]],
                      name: str,
                      in_axes: Sequence[Optional[int]],
                      out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                      donated_invars: Sequence[bool],
                      is_explicit_global_axis_size: bool,
                      *avals):
  pmap_computation = lower_parallel_callable(
      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, avals, lowering_platform=None)
  pmap_executable = pmap_computation.compile()
  return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])

@dataclasses.dataclass(frozen=True)
class ParallelCallableInfo:
  name: str
  backend: xc.Client
  axis_name: core.AxisName
  axis_size: int
  global_axis_size: int
  devices: Optional[Sequence[xc.Device]]
  in_axes: Iterable[Optional[int]]
  out_axes_thunk: Callable[[], Sequence[Optional[int]]]
  avals: Sequence[core.AbstractValue]

  @cached_property
  def local_devices(self):
    if self.devices:
      out = [d for d in self.devices
             if d.process_index == xb.process_index(self.backend)]
      assert len(out) > 0
    else:
      out = None  # type: ignore
    return out

  @cached_property
  def out_axes(self):
    return self.out_axes_thunk()

class ShardInfo(NamedTuple):
  sharded_avals: Sequence[core.AbstractValue]
  out_sharded_avals: Sequence[core.ShapedArray]
  global_sharded_avals: Sequence[core.AbstractValue]
  num_local_shards: int
  num_global_shards: int


class ReplicaInfo(NamedTuple):
  jaxpr_replicas: int
  num_local_replicas: int
  num_global_replicas: int


def find_replicas(
    jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int
) -> ReplicaInfo:
  # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
  jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
  num_local_replicas = axis_size * jaxpr_replicas
  num_global_replicas = global_axis_size * jaxpr_replicas
  return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)
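
# For example, a jaxpr with jaxpr_replicas=1 under axis_size=2 on this host
# and global_axis_size=4 across hosts yields ReplicaInfo(1, 2, 4).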

def stage_parallel_callable(
    pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> tuple[core.Jaxpr, list[Any], ReplicaInfo, ShardInfo]:
  sharded_avals = tuple(
      shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
      for axis, aval in safe_zip(pci.in_axes, pci.avals))

  with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):  # type: ignore
    with dispatch.log_elapsed_time(
        "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
        fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
          fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
  jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  assert len(out_sharded_avals) == len(pci.out_axes), (
      len(out_sharded_avals), len(pci.out_axes))

  # TODO(skye,mattjj): allow more collectives on multi-host as we test them,
  # but for now raise an error
  if pci.devices is not None:
    is_multi_host_pmap = len(pci.local_devices) != len(pci.devices)
  else:
    is_multi_host_pmap = xb.process_count(pci.backend) > 1
  if is_multi_host_pmap:
    check_multihost_collective_allowlist(jaxpr)

  replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
  num_local_shards = replicas.num_local_replicas
  num_global_shards = replicas.num_global_replicas

  shards = ShardInfo(
      sharded_avals, out_sharded_avals, sharded_avals,
      num_local_shards, num_global_shards)

  return jaxpr, consts, replicas, shards

@profiler.annotate_function
def lower_parallel_callable(
    fun: lu.WrappedFun,
    backend_name: Optional[str],
    axis_name: core.AxisName,
    axis_size: int,
    global_axis_size: int,
    devices: Optional[Sequence[xc.Device]],
    name: str,
    in_axes: Iterable[Optional[int]],
    out_axes_thunk: Callable[[], Sequence[Optional[int]]],
    donated_invars: Sequence[bool],
    is_explicit_global_axis_size: bool,
    avals: Sequence[core.AbstractValue],
    *,
    lowering_platform: Optional[str]):
  # Determine global_axis_size for use in AxisEnv.
  # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
  #   raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
  if (xb.process_count() == 1 and is_explicit_global_axis_size
      and global_axis_size != axis_size):
    raise ValueError(
        f"Specified axis_size {global_axis_size} doesn't match received "
        f"axis_size {axis_size}.")

  if devices is not None and backend_name is None:
    backend = xb.get_device_backend(devices[0])
  else:
    backend = xb.get_backend(backend_name)

  no_nested_sharding = False
  must_run_on_all_devices = False
  if not is_explicit_global_axis_size:
    if xb.process_count(backend) > 1:
      if devices:
        # This allows each host in a multi-host pmap to run on a different
        # number of devices, but precludes nested sharding (i.e. inner pmaps).
        no_nested_sharding = True
      else:
        # This assumes all hosts run on the same number of devices. We make
        # sure this assumption is true by requiring that the pmap is run on all
        # devices (and making the further assumption that each host has the
        # same number of devices). Nested sharding is ok in this case.
        must_run_on_all_devices = True

  pci = ParallelCallableInfo(
      name, backend, axis_name, axis_size, global_axis_size, devices,
      in_axes, out_axes_thunk, avals)
  jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  if logger.isEnabledFor(logging.DEBUG):
    logger.debug("sharded_avals: %s", shards.sharded_avals)
    logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
    logger.debug("num_replicas: %d num_local_replicas: %d",
                 replicas.num_global_replicas, replicas.num_local_replicas)
    logger.debug("devices: %s", devices)
    logger.debug("local_devices: %s", pci.local_devices)

  if (xb.process_count(backend) > 1 and must_run_on_all_devices and
      shards.num_local_shards != xb.local_device_count(backend)):
    if shards.num_local_shards == axis_size:
      raise ValueError(
          f"On multi-host platforms, the input to pmapped functions must have "
          f"leading axis size equal to the number of local devices if no "
          f"`devices` argument is specified. Got {axis_size=}, "
          f"num_local_devices={xb.local_device_count(backend)}")
    else:
      raise ValueError(
          f"On multi-host platforms, pmapped functions must run across all "
          f"devices, i.e. num_replicas * num_partitions should equal the "
          f"number of local devices. Got "
          f"num_replicas={replicas.num_local_replicas}, and "
          f"num_local_devices={xb.local_device_count(backend)}")

  if no_nested_sharding and replicas.jaxpr_replicas > 1:
    raise ValueError(
        f"On multi-host platforms, pmapped functions that both have `devices` "
        f"specified and contain an inner_pmap must specify an "
        f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
        f"{replicas.jaxpr_replicas}")

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority,
               "Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
               fun.__name__, id(fun),
               shards.num_global_shards, avals, replicas.num_global_replicas)

  axis_env = sharding_impls.AxisEnv(
      replicas.num_global_replicas, (axis_name,), (global_axis_size,))
  name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  replicated_args = [axis is None for axis in in_axes]
  tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                          backend.platform)
  module_name = f"pmap_{fun.__name__}"
  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
    ordered_effects = list(
        effects.ordered_effects.filter_in(closed_jaxpr.effects))
    if ordered_effects:
      raise ValueError("Ordered effects not supported in `pmap`.")
    unordered_effects = list(
        effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
    with dispatch.log_elapsed_time(
        "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
        fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
      lowering_result = mlir.lower_jaxpr_to_module(
          module_name,
          closed_jaxpr,
          ordered_effects,
          backend,
          lowering_platform or backend.platform,
          sharding_impls.ReplicaAxisContext(axis_env),
          name_stack,
          donated_invars,
          replicated_args=replicated_args,
          arg_shardings=None,
          result_shardings=None,
          arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
          result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
          num_replicas=replicas.num_global_replicas)
  return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
                         shards=shards, tuple_args=tuple_args,
                         unordered_effects=unordered_effects,
                         ordered_effects=ordered_effects,
                         keepalive=lowering_result.keepalive,
                         host_callbacks=lowering_result.host_callbacks,
                         jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)

class PmapComputation(stages.XlaLowering):
  _hlo: ir.Module
  _executable: Optional[PmapExecutable]

  def __init__(self, hlo: ir.Module, **compile_args):
    self._executable = None
    self._hlo = hlo
    self.compile_args = compile_args

  # -- stages.XlaLowering overrides

  def stablehlo(self) -> ir.Module:
    return self._hlo

  @profiler.annotate_function
  def compile(self, compiler_options=None) -> PmapExecutable:
    if self._executable is None or compiler_options is not None:
      executable = UnloadedPmapExecutable.from_hlo(
          self._hlo, **self.compile_args,
          compiler_options=compiler_options)
      if compiler_options is None:
        self._executable = executable
      return executable
    return self._executable

def _cast_to_shaped_array(aval: core.AbstractValue) -> ShapedArray:
  assert isinstance(aval, ShapedArray), aval
  return cast(ShapedArray, aval)

@dataclasses.dataclass
class UnloadedPmapExecutable:
  compiled: Any
  backend: xb.XlaBackend
  local_input_avals: Sequence[core.AbstractValue]
  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  local_output_avals: Sequence[ShapedArray]
  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  unordered_effects: list[core.Effect]
  ordered_effects: list[core.Effect]
  keepalive: Sequence[Any]
  host_callbacks: Sequence[Any]
  jaxpr_debug_info: core.JaxprDebugInfo

  def build_execute_fun(self):
    input_indices = []
    for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
      assert isinstance(spec, sharding_impls.PmapSharding), spec
      assert isinstance(aval, core.ShapedArray), aval
      input_indices.append(
          sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
          if spec.sharding_spec is not None else None)
    handle_outs = local_avals_to_results_handler(self.local_output_avals,
                                                 self.output_shardings)
    handle_args = InputsHandler(self.compiled.local_devices(),
                                self.input_shardings, input_indices)
    execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
                                    self.backend, handle_args, handle_outs,
                                    self.unordered_effects,
                                    self.ordered_effects, self.keepalive,
                                    bool(self.host_callbacks),
                                    set(range(len(input_indices))))
    return execute_fun

  def load(self) -> PmapExecutable:
    fingerprint = getattr(self.compiled, "fingerprint", None)

    return PmapExecutable(
        self.compiled, self.build_execute_fun, fingerprint,
        self.local_input_avals, self.jaxpr_debug_info, self)

  @staticmethod
  def from_hlo(hlo: ir.Module,
               pci: ParallelCallableInfo,
               replicas: ReplicaInfo,
               shards: ShardInfo,
               tuple_args: bool,
               unordered_effects: list[core.Effect],
               ordered_effects: list[core.Effect],
               host_callbacks: list[Any],
               keepalive: Any,
               jaxpr_debug_info: core.JaxprDebugInfo,
               compiler_options=None):
    devices = pci.devices
    if devices is None:
      if shards.num_global_shards > xb.device_count(pci.backend):
        msg = ("compiling computation that requires {} logical devices, "
               "but only {} XLA devices are available (num_replicas={})")
        raise ValueError(msg.format(shards.num_global_shards,
                                    xb.device_count(pci.backend),
                                    replicas.num_global_replicas))
      # On a single host, we simply grab the first N devices from jax.devices().
      # In the single host case, we want the default device order of pmap to
      # match jax.devices().
      # On multiple hosts, we create a default device assignment that ensures
      # each host is responsible for a contiguous set of replicas.
      if shards.num_global_shards > shards.num_local_shards:
        # TODO(skye): use a locality-aware assignment that satisfies the above
        # constraint.
        devices = [d for process_index in range(xb.process_count(pci.backend))
                   for d in xb.local_devices(process_index, pci.backend)]
      else:
        devices = xb.local_devices(backend=pci.backend)[:shards.num_local_shards]
    else:
      if shards.num_local_shards != len(pci.local_devices):
        local_devices_str = ", ".join(map(str, pci.local_devices))
        if shards.num_local_shards == pci.axis_size:
          raise ValueError(
              f"Leading axis size of input to pmapped function must equal the "
              f"number of local devices passed to pmap. Got axis_size="
              f"{pci.axis_size}, num_local_devices={len(pci.local_devices)}.\n"
              f"(Local devices available to pmap: {local_devices_str})")
        else:
          raise ValueError(
              f"pmapped function requires {shards.num_local_shards} local "
              f"devices to run due to nested pmapped or other parallel "
              f"functions, but only {len(pci.local_devices)} are available.\n"
              f"(outer axis size: {pci.axis_size}, local devices available to "
              f"pmap: {local_devices_str})")
      if shards.num_global_shards != len(devices):
        raise ValueError("compiling computation that creates %s shards, "
                         "but %s devices were specified" %
                         (shards.num_global_shards, len(devices)))

    # 'devices' may be 1D or 2D at this point (e.g.
    # get_default_device_assignment() returns 2D assignment, caller may have
    # provided 1D list of devices).
    # Convert to 2D in case it's 1D and we have > 1 partitions.
    num_partitions = 1
    device_assignment: np.ndarray = np.array(devices).reshape(
        (replicas.num_global_replicas, num_partitions))
    compile_options = xb.get_compile_options(
        num_replicas=replicas.num_global_replicas,
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        use_spmd_partitioning=False,
        env_options_overrides=compiler_options,
    )
    compile_options.parameter_is_tupled_arguments = tuple_args

    process_index = xb.process_index(pci.backend)
    local_device_assignment = np.array([
        d for d in device_assignment.flat if d.process_index == process_index
    ])

    input_sharding_specs = [
        sharding_specs.pmap_sharding_spec(
            replicas.num_local_replicas, pci.axis_size,
            cast(ShapedArray, aval).shape, in_axis)
        for aval, in_axis in safe_zip(shards.sharded_avals, pci.in_axes)]
    in_shardings = _get_pmap_sharding(local_device_assignment,
                                      input_sharding_specs)

    local_unmapped_avals = [
        _cast_to_shaped_array(
            core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
        if out_axis is not None else aval
        for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
    out_specs = [
        sharding_specs.pmap_sharding_spec(
            replicas.num_local_replicas, pci.axis_size, aval.shape, out_axis)
        for aval, out_axis in safe_zip(
            shards.out_sharded_avals, pci.out_axes)]
    out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)

    if hasattr(pci.backend, "compile_replicated"):
      input_indices = [
          sharding_specs.spec_to_indices(aval.shape, spec)
          if spec is not None else None
          for aval, spec in safe_zip(pci.avals, input_sharding_specs)
      ]
      handle_outs = local_avals_to_results_handler(local_unmapped_avals,
                                                   out_shardings)
      return _compile_replicated_pmap_executable_from_hlo(
          hlo, pci, input_indices, in_shardings, handle_outs,
          compile_options, host_callbacks, bool(unordered_effects),
          ordered_effects, jaxpr_debug_info)

    with dispatch.log_elapsed_time(
        "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
        fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
      compiled = dispatch.compile_or_get_cached(
          pci.backend, hlo, device_assignment, compile_options,
          host_callbacks)

    return UnloadedPmapExecutable(
        compiled=compiled,
        backend=pci.backend,
        local_input_avals=pci.avals,
        input_shardings=in_shardings,
        local_output_avals=local_unmapped_avals,
        output_shardings=out_shardings,
        unordered_effects=unordered_effects,
        ordered_effects=ordered_effects,
        keepalive=keepalive,
        host_callbacks=host_callbacks,
        jaxpr_debug_info=jaxpr_debug_info).load()


def _compile_replicated_pmap_executable_from_hlo(
    hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
    compile_options, host_callbacks, has_unordered_effects, ordered_effects,
    jaxpr_debug_info):
  # Use the standard out_handler.
  execute_fun = pci.backend.compile_replicated(
      is_trivial=False, name=pci.name, computation=hlo,
      compile_options=compile_options, host_callbacks=host_callbacks,
      has_unordered_effects=has_unordered_effects,
      ordered_effects=ordered_effects, in_avals=pci.avals,
      in_indices=input_indices, in_shardings=in_shardings,
      kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
  # TODO(frostig): need `compile_replicated` to give us the XLA executable
  return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
                        jaxpr_debug_info, None)


class PmapExecutable(stages.XlaExecutable):
  __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
               "fingerprint", "in_avals", "_jaxpr_debug_info",
               "_unloaded_executable"]

  def __init__(self, xla_executable, build_unsafe_call, fingerprint,
               in_avals, jaxpr_debug_info, unloaded_executable):
    self.xla_executable = xla_executable
    self._unsafe_call = None
    self.build_unsafe_call = build_unsafe_call
    self.fingerprint = fingerprint
    self.in_avals = in_avals
    self._jaxpr_debug_info = jaxpr_debug_info
    self._unloaded_executable = unloaded_executable

  @property
  def unsafe_call(self) -> Callable[..., Any]:
    if self._unsafe_call is None:
      self._unsafe_call = self.build_unsafe_call()
    return self._unsafe_call

  # -- stages.XlaExecutable overrides

  def xla_extension_executable(self):
    return self.xla_executable

  @profiler.annotate_function
  def call(self, *args):
    # TODO(frostig): do we need to check sharding and sharded avals?
    arg_avals = map(xla.abstractify, args)
    check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
    return self.unsafe_call(*args)  # pylint: disable=not-callable


def _get_pmap_sharding(devices, specs):
  return [sharding_impls.PmapSharding(devices, spec) for spec in specs]


multi_host_supported_collectives: set[core.Primitive] = set()


def check_multihost_collective_allowlist(jaxpr):
  used_collectives = set(xla.jaxpr_collectives(jaxpr))
  if not used_collectives.issubset(multi_host_supported_collectives):
    bad_collectives = used_collectives - multi_host_supported_collectives
    msg = "using collectives that aren't supported for multi-host: {}"
    raise TypeError(msg.format(", ".join(map(str, bad_collectives))))


class InputsHandler:
  __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")

  def __init__(self, local_devices, in_shardings, input_indices):
    self.handler = partial(
        shard_args, local_devices, input_indices, in_shardings)
    self.local_devices = local_devices
    self.in_shardings = in_shardings
    self.input_indices = input_indices

  def __call__(self, input_buffers):
    return self.handler(input_buffers)

  def __str__(self):
    return ("InputsHandler(\n"
            f"local_devices={self.local_devices},\n"
            f"in_shardings={self.in_shardings},\n"
            f"input_indices={self.input_indices})")


class ResultsHandler:
  # `out_avals` is the `GlobalDeviceArray` global avals when using pjit or xmap
  # with `config.parallel_functions_output_gda=True`. It is the local one
  # otherwise, and also when using `pmap`.
  __slots__ = ("handlers", "out_shardings", "out_avals")

  def __init__(self, handlers, out_shardings, out_avals):
    self.handlers = handlers
    self.out_shardings = out_shardings
    self.out_avals = out_avals

  def __call__(self, out_bufs):
    return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]


def local_avals_to_results_handler(
    unmapped_local_out_avals: Sequence[ShapedArray],
    local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
  out_indices = [tuple(s.devices_indices_map(aval.shape).values())
                 for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
  handlers = [
      local_aval_to_result_handler(aval, s, idcs)
      for aval, s, idcs in safe_zip(unmapped_local_out_avals, local_shardings, out_indices)
  ]
  return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals)


def global_avals_to_results_handler(
    global_out_avals: Sequence[ShapedArray],
    shardings: Sequence[sharding_impls.XLACompatibleSharding],
    committed: bool,
    are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
  handlers = [
      global_aval_to_result_handler(global_aval, s, committed, x)
      for global_aval, s, x in safe_zip(global_out_avals, shardings,
                                        are_out_shardings_from_xla)
  ]
  return ResultsHandler(handlers, shardings, global_out_avals)


@profiler.annotate_function
def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
  """Replicates ``val`` across multiple devices.

  Args:
    val: the value to be replicated.
    axis_size: the length of the output, i.e. the logical number of replicas
      to create. Usually equal to `nrep`, but in the case of nested pmaps,
      `nrep` may be a multiple of `axis_size`.
    nrep: the number of replicas to create. If ``devices`` is set, must be
      equal to ``len(devices)``.
    devices: the devices to replicate across. If None, ``nrep`` will be used
      to generate a default device assignment.
    backend: string specifying which backend to use.
    in_axis: axis along which the value is to be replicated.

  Returns:
    A ShardedDeviceArray of length `axis_size` where each shard is equal to
    ``val``.
  """
  device_count = (len(devices) if devices else xb.local_device_count(backend))
  if nrep > device_count:
    msg = ("Cannot replicate across %d replicas because only %d local devices "
           "are available." % (nrep, device_count))
    if devices:
      msg += (" (local devices = %s)"
              % ", ".join(map(str, devices)) if devices else str(None))
    raise ValueError(msg)

  if devices is None:
    assert nrep is not None
    # TODO(skye): use different device assignment on multihost
    devices = xb.get_backend(backend).get_default_device_assignment(nrep)
  assert nrep == len(devices)

  aval = xla.abstractify(val)
  if in_axis is not None:
    replicated_aval = aval.update(shape=(axis_size,) + aval.shape)
  else:
    replicated_aval = aval
  # TODO(skye): figure out how partitioning should work here
  sharding_spec = sharding_specs.pmap_sharding_spec(
      nrep, axis_size, aval.shape, in_axis)

  buf = jax.device_put(val, devices[0])
  sharding = sharding_impls.PmapSharding(
      np.asarray([d for d in devices]), sharding_spec)
  return batched_device_put(replicated_aval, sharding, [buf] * len(devices),
                            devices)
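
# Example usage (a hedged sketch with made-up numbers): on a single host with
# 8 local devices,
#   x = replicate(np.float32(1.0), axis_size=8, nrep=8)
# produces an array of shape (8,) whose shard on device i equals 1.0.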


class ExecuteReplicated:
  """The logic to shard inputs, execute a replicated model, returning outputs."""
  __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
               'has_unordered_effects', 'ordered_effects', 'keepalive',
               'has_host_callbacks', '_local_devices', 'kept_var_idx',
               '__weakref__']

  def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
               out_handler: ResultsHandler,
               unordered_effects: list[core.Effect],
               ordered_effects: list[core.Effect], keepalive: Any,
               has_host_callbacks: bool, kept_var_idx: set[int]):
    self.xla_executable = xla_executable
    self.name = name
    self.backend = backend
    self.in_handler = in_handler
    self.out_handler = out_handler
    self.has_unordered_effects = bool(unordered_effects)
    self.ordered_effects = ordered_effects
    self._local_devices = self.xla_executable.local_devices()
    if ordered_effects:
      assert len(self._local_devices) == 1
    self.keepalive = keepalive
    self.has_host_callbacks = has_host_callbacks
    self.kept_var_idx = kept_var_idx

  def _add_tokens_to_inputs(self, input_bufs):
    if self.ordered_effects:
      device, = self._local_devices
      tokens = [list(dispatch.runtime_tokens.get_token(eff, device))
                for eff in self.ordered_effects]
      input_bufs = [*tokens, *input_bufs]
    return input_bufs

  def _handle_token_bufs(self, token_bufs, sharded_token):
    for i, device in enumerate(self._local_devices):
      dispatch.runtime_tokens.set_output_runtime_token(
          device, sharded_token.get_token(i))
    for eff, token_buf in zip(self.ordered_effects, token_bufs):
      dispatch.runtime_tokens.update_token(eff, token_buf)

  def _call_with_tokens(self, input_bufs):
    input_bufs = self._add_tokens_to_inputs(input_bufs)
    out_bufs, sharded_token = (
        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
            input_bufs
        )
    )
    num_output_tokens = len(self.ordered_effects)
    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
    self._handle_token_bufs(token_bufs, sharded_token)
    return out_bufs

  @profiler.annotate_function
  def __call__(self, *args):
    args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
    input_bufs = self.in_handler(args)
    if (self.ordered_effects or self.has_unordered_effects
        or self.has_host_callbacks):
      input_bufs = self._add_tokens_to_inputs(input_bufs)
      results = self.xla_executable.execute_sharded(
          input_bufs, with_tokens=True
      )
      self._handle_token_bufs(
          results.disassemble_prefix_into_single_device_arrays(
              len(self.ordered_effects)),
          results.consume_token())
    else:
      results = self.xla_executable.execute_sharded(input_bufs)
    if dispatch.needs_check_special():
      out_arrays = results.disassemble_into_single_device_arrays()
      for arrays in out_arrays:
        dispatch.check_special(self.name, arrays)
      return self.out_handler(out_arrays)
    return results.consume_with_handlers(self.out_handler.handlers)


xla_pmap_p = core.MapPrimitive('xla_pmap')
xla_pmap = xla_pmap_p.bind
xla_pmap_p.def_impl(xla_pmap_impl)


def _pmap_partial_eval_custom_params_updater(
    unks_in, inst_in, kept_outs_known, kept_outs_staged, num_res, params_known,
    params_staged):
  # prune inputs to jaxpr_known according to unks_in
  donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
  in_axes_known, _ = partition_list(unks_in, params_known['in_axes'])
  _, out_axes_known = partition_list(kept_outs_known, params_known['out_axes'])
  out_axes_known = out_axes_known + [0] * num_res
  new_params_known = dict(params_known, in_axes=tuple(in_axes_known),
                          out_axes=tuple(out_axes_known),
                          donated_invars=tuple(donated_invars_known))

  # added num_res new inputs to jaxpr_staged, pruning according to inst_in
  _, donated_invars_staged = partition_list(inst_in, params_staged['donated_invars'])
  donated_invars_staged = [False] * num_res + donated_invars_staged
  _, in_axes_staged = partition_list(inst_in, params_staged['in_axes'])
  in_axes_staged = [0] * num_res + in_axes_staged
  _, out_axes_staged = partition_list(kept_outs_staged, params_staged['out_axes'])
  new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
                           out_axes=tuple(out_axes_staged),
                           donated_invars=tuple(donated_invars_staged))
  return new_params_known, new_params_staged


def _pmap_partial_eval_custom_res_maker(params_known, aval):
  return core.unmapped_aval(params_known['axis_size'], core.no_axis_name, 0, aval)


def _pmap_dce_rule(used_outputs, eqn):
  # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
  with maybe_extend_axis_env(eqn.params['axis_name'],
                             eqn.params['global_axis_size'], None):
    new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
  _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
  _, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
  _, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
  new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
                    donated_invars=tuple(donated_invars),
                    in_axes=tuple(in_axes), out_axes=tuple(out_axes))
  if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
    return used_inputs, None
  else:
    new_eqn = pe.new_jaxpr_eqn(
        [v for v, used in zip(eqn.invars, used_inputs) if used],
        [v for v, used in zip(eqn.outvars, used_outputs) if used],
        eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
    return used_inputs, new_eqn


# Set param update handlers to update `donated_invars` just like xla_call_p
pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
    partial(pe.call_partial_eval_custom_rule,
            'call_jaxpr', _pmap_partial_eval_custom_params_updater,
            res_aval=_pmap_partial_eval_custom_res_maker)
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params

ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)


def _pmap_axis_subst(params, subst, traverse):
  if 'call_jaxpr' not in params:
    return params
  if not traverse:
    return params
  def shadowed_subst(name):
    return (name,) if name in params['axis_name'] else subst(name)
  with maybe_extend_axis_env(params['axis_name'],
                             params['global_axis_size'], None):
    new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'],
                                            shadowed_subst)
  return dict(params, call_jaxpr=new_jaxpr)
core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst


def _unravel_index_hlo(axis_env):
  div = mlir.ir_constant(
      np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
  mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
  return hlo.RemOp(
      hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result
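
# The ops above compute (replica_id // div) % sizes[-1], i.e. this replica's
# position along the innermost mapped axis. Worked example (hypothetical
# numbers): with nreps=8 and sizes=(2, 4), div = 8 // prod((2, 4)) = 1, so
# replica 6 maps to (6 // 1) % 4 = 2.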

def _hlo_shard(aval, axis_env, xs, in_axis):
  if aval is core.abstract_token:
    return xs
  elif isinstance(aval, core.ShapedArray):
    x, = xs
    dims = list(aval.shape)
    zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
    idxs = [zero] * len(dims)
    idxs.insert(in_axis, _unravel_index_hlo(axis_env))
    dims_unsqueezed = dims.copy()
    dims_unsqueezed.insert(in_axis, 1)
    dynamic_slice_result = hlo.DynamicSliceOp(
        x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
    return [
        hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
    ]
  else:
    raise TypeError(aval)


# TODO(b/110096942): more efficient gather
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
  if aval is core.abstract_token:
    return xs
  elif isinstance(aval, core.ShapedArray):
    x, = xs
    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    convert_bool = (np.issubdtype(aval.dtype, np.bool_)
                    and platform in ('cpu', 'gpu'))
    if convert_bool:
      aval = aval.update(dtype=np.dtype(np.float32))
      x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result

    dims = list(aval.shape)
    padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
    padded = mlir.full_like_aval(ctx, 0, padded_aval)
    zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
    idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
    broadcast_result = hlo.BroadcastOp(
        x, mlir.dense_int_elements([1])).result
    padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
    replica_groups = mlir.dense_int_elements(
        xla.axis_groups(axis_env, axis_env.names[-1]))
    out = hlo.CrossReplicaSumOp(padded, replica_groups).result
    if out_axis != 0:
      # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
      perm = list(range(1, len(dims)))
      perm.insert(out_axis, 0)
      transposed_dims = list(dims)
      transposed_dims.insert(out_axis, axis_env.sizes[-1])
      aval = aval.update(shape=transposed_dims)
      out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result

    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    if convert_bool:
      float_zero = mlir.full_like_aval(ctx, 0, padded_aval)
      out = hlo.CompareOp(
          out,
          float_zero,
          hlo.ComparisonDirectionAttr.get("NE"),
          compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result
    return out
  else:
    raise TypeError(aval)


def _pmap_lowering(ctx, *in_nodes, axis_name,
                   axis_size, global_axis_size, devices, name,
                   call_jaxpr, backend=None, in_axes, out_axes,
                   donated_invars, is_explicit_global_axis_size):
  del donated_invars  # Unused.
  xla.check_backend_matches(backend, ctx.module_context.platform)
  # We in-line here rather than generating a Call HLO as in the xla_call
  # translation rule just because the extra tuple stuff is a pain.
  if ctx.module_context.axis_env.names and devices is not None:
    raise ValueError("Nested pmap with explicit devices argument.")
  new_env = xla.extend_axis_env(ctx.module_context.axis_env, axis_name,
                                global_axis_size)
  # Shard the in_nodes that are mapped
  in_avals = [v.aval for v in call_jaxpr.invars]
  in_nodes_sharded = (
      _hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
      if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
      for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))

  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
    sub_ctx = ctx.module_context.replace(
        axis_context=sharding_impls.ReplicaAxisContext(new_env),
        name_stack=ctx.module_context.name_stack.extend(
            util.wrap_name(name, 'pmap')))
    sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
                                         *in_nodes_sharded,
                                         dim_var_values=ctx.dim_var_values)
  out_avals = [v.aval for v in call_jaxpr.outvars]
  outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard,
                       platform=ctx.module_context.platform)
          for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
  return outs

mlir.register_lowering(xla_pmap_p, _pmap_lowering)


# ------------------- xmap -------------------

def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
  assert isinstance(aval, ShapedArray)
  shape = list(aval.shape)
  named_shape = dict(aval.named_shape)
  for name, axis in in_axes.items():
    assert shape[axis] % axis_sizes[name] == 0
    assert name not in named_shape
    named_shape[name] = axis_sizes[name]
    shape[axis] //= axis_sizes[name]
  return aval.update(shape=tuple(shape), named_shape=named_shape)
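
# Worked example (hypothetical sizes): with axis_sizes={'x': 4} and
# in_axes={'x': 0}, tiling f32[8, 3] yields f32[2, 3] with named_shape
# {'x': 4}. untile_aval_nd below is the inverse: it multiplies the positional
# dimension back out and drops the axis name.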

def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
  assert isinstance(aval, ShapedArray)
  shape = list(aval.shape)
  named_shape = dict(aval.named_shape)
  for name, axis in out_axes.items():
    shape[axis] *= axis_sizes[name]
    named_shape.pop(name, None)  # The name might be missing --- it's a broadcast.
  return aval.update(shape=tuple(shape), named_shape=named_shape)


def mesh_local_to_global(mesh, axes: ArrayMapping, aval):
  return untile_aval_nd(mesh.shape, axes,
                        tile_aval_nd(mesh.local_mesh.shape, axes, aval))

def mesh_global_to_local(mesh, axes: ArrayMapping, aval):
  return untile_aval_nd(mesh.local_mesh.shape, axes,
                        tile_aval_nd(mesh.shape, axes, aval))
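
# These helpers rescale an aval between local and global views by tiling with
# one mesh's sizes and untiling with the other's. Sketch (hypothetical mesh):
# if a global mesh {'x': 4} spans 2 processes, the local mesh is {'x': 2}, so
# mesh_global_to_local maps a global f32[8] (mapped on 'x' at dim 0) to the
# local f32[4], and mesh_local_to_global inverts that.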


class SPMDBatchTrace(batching.BatchTrace):
  def get_axis_primitive_batcher(self, primitive, frame):
    if primitive in spmd_primitive_batchers:
      return partial(spmd_primitive_batchers[primitive],
                     frame.size, frame.name, frame.main_trace.trace_type)
    return super().get_axis_primitive_batcher(primitive, frame)


spmd_primitive_batchers: dict[core.Primitive, Callable] = {}


def vtile_by_mesh(fun: lu.WrappedFun,
                  mesh: Mesh,
                  in_axes: Sequence[ArrayMapping],
                  out_axes: Sequence[ArrayMapping]):
  # We vectorize in reversed order, because vmap is often biased towards
  # moving the batch axis to the front, and this way of stacking transforms
  # will order the batch axes according to the mesh axis order.
  # Not strictly necessary, but seems nicer than reversing it?
  for name, size in reversed(mesh.shape.items()):
    fun = batching.vtile(fun,
                         tuple(a.get(name, None) for a in in_axes),
                         tuple(a.get(name, None) for a in out_axes),
                         tile_size=size,
                         axis_name=name,
                         main_type=SPMDBatchTrace)
  return fun
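
# Usage sketch (hypothetical mesh, not from a test here): with
# mesh.shape == {'x': 2, 'y': 4}, the loop above wraps `fun` once per mesh
# axis in reversed order, so the 'y' vtile is applied first (innermost) and
# the 'x' vtile ends up outermost, ordering batch axes like the mesh axes.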

full_to_shard_p = core.Primitive('full_to_shard')

@full_to_shard_p.def_abstract_eval
def _full_to_shard_abstract_eval(x, axes, mesh, **_):
  # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
  return tile_aval_nd(mesh.shape, axes, x)

def manual_proto(
    aval: core.ShapedArray,
    manual_axes_set: frozenset[sharding_impls.MeshAxisName], mesh: Mesh):
  """Create an OpSharding proto that declares all mesh axes from
  `manual_axes_set` as manual and all others as replicated.
  """
  named_mesh_shape = mesh.shape
  mesh_shape = list(named_mesh_shape.values())
  axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}

  manual_axes = list(sorted(manual_axes_set, key=str))
  replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set)

  tad_perm = ([axis_order[a] for a in replicated_axes] +
              [axis_order[a] for a in manual_axes])
  tad_shape = [1] * aval.ndim
  tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
  tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))

  raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
  proto = xc.OpSharding()
  proto.type = xc.OpSharding.Type.OTHER
  proto.tile_assignment_dimensions = tad_shape
  proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat)
  proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
  return proto
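
# Worked example (hypothetical 2x2 mesh): for a rank-1 aval with mesh
# {'x': 2, 'y': 2} and manual_axes_set == {'x'}, the replicated axes are
# ['y'], so tad_shape == [1, 2, 2]; the device grid [[0, 1], [2, 3]] is
# transposed to put 'y' first, giving tile_assignment_devices == [0, 2, 1, 3],
# and last_tile_dims marks the two trailing dims as [REPLICATED, MANUAL].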

@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
                            manual_axes: frozenset[sharding_impls.MeshAxisName]):
  # TODO: Can we short-circuit for replicated values? Probably not.
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  sharding_proto = mesh_sharding_specs(
      mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto().to_proto()
  unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto, unspecified_dims=unspecified_dims)
  proto = manual_proto(aval_in, manual_axes, mesh)
  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto, unspecified_dims=unspecified_dims),

shard_to_full_p = core.Primitive('shard_to_full')

@shard_to_full_p.def_abstract_eval
def _shard_to_full_abstract_eval(x, axes, mesh, **_):
  # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
  return untile_aval_nd(mesh.shape, axes, x)

@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping, mesh: Mesh,
                            manual_axes: frozenset[sharding_impls.MeshAxisName]):
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  proto = manual_proto(aval_in, manual_axes, mesh)  # type: ignore
  unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())  # type: ignore
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto, unspecified_dims=unspecified_dims)
  sharding_proto = mesh_sharding_specs(
      mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto().to_proto()
  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),

@lu.transformation
def vtile_manual(manual_axes: frozenset[sharding_impls.MeshAxisName],
                 mesh: Mesh,
                 in_axes: Sequence[ArrayMapping],
                 out_axes: Sequence[ArrayMapping],
                 *args):
  tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes)
                for arg, axes in zip(args, in_axes)]
  tiled_outs = yield tiled_args, {}
  outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes)
          for out, axes in zip(tiled_outs, out_axes)]
  yield outs


@dataclasses.dataclass(frozen=True)
class TileVectorize:
  pass

@dataclasses.dataclass(frozen=True)
class TileManual:
  manual_axes: frozenset[sharding_impls.MeshAxisName]

TilingMethod = Union[TileVectorize, TileManual]


def check_if_any_auto(
    shardings: Iterable[Union[sharding_impls.XLACompatibleSharding,
                              AUTO, UnspecifiedValue]]) -> bool:
  for s in shardings:
    if is_auto(s):
      return True
  return False

class MismatchType(enum.Enum):
  ARG_SHARDING = 0
  OUT_SHARDING = 1
  SHARDING_INSIDE_COMPUTATION = 2
  CONTEXT_DEVICES = 3
  IN_SHARDING = 4

  def __str__(self):
    if self.name == 'IN_SHARDING':
      return 'explicit input sharding'
    elif self.name == 'OUT_SHARDING':
      return 'explicit output sharding'
    elif self.name == 'CONTEXT_DEVICES':
      return 'devices'
    return f'{self.name}'


@dataclasses.dataclass
class DeviceAssignmentMismatch:
  da: Sequence[xc.Device]
  m_type: MismatchType
  source_info: Optional[dispatch.SourceInfo]

  @property
  def device_ids(self) -> Sequence[int]:
    return [d.id for d in self.da]

  @property
  def platform(self) -> str:
    return self.da[0].platform.upper()

  def _maybe_api_name(self, api_name) -> str:
    return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else ""

  @property
  def source_info_str(self):
    return "" if self.source_info is None else f" at {self.source_info.source_info}"

  @property
  def _dev_ids_plat_str(self):
    return f"device ids {self.device_ids} on platform {self.platform}"

  def m_type_str(self, api_name):
    return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}'
            if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type)

  def _str(self, api_name):
    return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with "
            f"{self._dev_ids_plat_str}{self.source_info_str}")


class DeviceAssignmentMismatchError(Exception):
  pass


ShardingInfo = tuple[
    Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
    MismatchType, Optional[Any]]  # Any is dispatch.SourceInfo to avoid circular imports


def _get_default_device() -> xc.Device:
  return config.jax_default_device or xb.local_devices()[0]


def _get_and_check_device_assignment(
    shardings: Iterable[ShardingInfo],
    devices: Optional[Sequence[xc.Device]],
) -> tuple[xc.Client, tuple[xc.Device, ...]]:
  first_sharding_info = None
  if devices is None:
    devices = ()
  else:
    devices = tuple(devices)

  for i, s_type, source_info in shardings:
    if is_unspecified(i):
      continue

    if first_sharding_info is None:
      first_sharding_info = (
          (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i)  # type: ignore
          else (i._device_assignment, s_type, source_info))  # type: ignore
    arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment  # type: ignore
    if not devices:
      if first_sharding_info[0] != arr_device_assignment:
        raise DeviceAssignmentMismatchError([
            DeviceAssignmentMismatch(*first_sharding_info),
            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
    else:
      if devices != arr_device_assignment:
        raise DeviceAssignmentMismatchError([
            DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
  if first_sharding_info is None and devices:
    final_device_assignment = devices
  elif first_sharding_info is None:
    final_device_assignment = (_get_default_device(),)
  else:
    final_device_assignment = first_sharding_info[0]
  return xb.get_device_backend(final_device_assignment[0]), final_device_assignment

MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]

def cache_wrap(fn):
  _wrapped_with_lu_cache = lu.cache(fn)
  _wrapped_with_weakref_lru_cache = weakref_lru_cache(fn)
  def wrapped(f, *args, **kwargs):
    if isinstance(f, lu.WrappedFun):
      return _wrapped_with_lu_cache(f, *args, **kwargs)
    else:
      return _wrapped_with_weakref_lru_cache(f, *args, **kwargs)
  return wrapped
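
# The wrapper dispatches on the first argument's type: lu.WrappedFun inputs go
# through lu.cache, while anything else (e.g. a ClosedJaxpr) goes through
# weakref_lru_cache, which holds the key weakly so that the cache entry does
# not keep the traced object alive.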


@cache_wrap
def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
                            keep_unused, donated_invars, auto_spmd_lowering):
  name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

  if isinstance(fun_or_jaxpr, lu.WrappedFun):
    with dispatch.log_elapsed_time(
        "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
        fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
          fun_or_jaxpr, global_in_avals)
  else:
    assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
    jaxpr = fun_or_jaxpr.jaxpr
    global_out_avals = fun_or_jaxpr.out_avals
    consts = fun_or_jaxpr.consts

  if (keep_unused or auto_spmd_lowering or
      any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
          for a in global_in_avals)):
    kept_var_idx = set(range(len(global_in_avals)))
  else:
    jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
    donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
    del kept_const_idx

  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
          kept_var_idx, name_stack)


@dataclasses.dataclass(frozen=True)
class SemanticallyEqualShardings:
  shardings: tuple[Union[sharding_impls.GSPMDSharding, UnspecifiedValue], ...]

  def __hash__(self):
    return hash(tuple(
        s._hlo_sharding_hash if isinstance(s, sharding_impls.GSPMDSharding) else s  # type: ignore
        for s in self.shardings))

  def __eq__(self, other):
    if not isinstance(other, SemanticallyEqualShardings):
      return False
    return all(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
               if (isinstance(s, sharding_impls.GSPMDSharding) and
                   isinstance(o, sharding_impls.GSPMDSharding))
               else s == o for s, o in zip(self.shardings, other.shardings))
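
# Wrapping shardings in SemanticallyEqualShardings makes the weakref-cached
# lowering below key on HLO sharding semantics rather than object identity:
# two GSPMDSharding instances describing the same OpSharding hash and compare
# equal, and so can hit the same _cached_lowering_to_hlo entry.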


@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                            semantic_in_shardings, semantic_out_shardings,
                            da_object, lowering_platform,
                            donated_invars, name_stack, override_lowering_rules):
  jaxpr = closed_jaxpr.jaxpr
  in_shardings = semantic_in_shardings.shardings
  out_shardings = semantic_out_shardings.shardings
  global_in_avals = closed_jaxpr.in_avals
  global_out_avals = closed_jaxpr.out_avals
  device_assignment = da_object.device_assignment

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority,
               "Compiling %s with global shapes and types %s. "
               "Argument mapping: %s.",
               fun_name, global_in_avals, in_shardings)

  # Look at the number of replicas present in the jaxpr. In
  # lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
  # handled here so as to deprecate the lower_xla_callable codepath when
  # `jax.Array` is turned on by default.
  # TODO(yashkatariya): Remove this when `jit(pmap)` is removed.
  nreps = dispatch.jaxpr_replicas(jaxpr)
  dispatch.raise_warnings_or_errors_for_jit_of_pmap(
      nreps, backend, fun_name, jaxpr)

  in_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
  out_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
  axis_ctx: mlir.AxisContext

  if nreps == 1:
    in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
    out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
    replicated_args = [False] * len(global_in_avals)
    axis_ctx = sharding_impls.ShardingContext(device_assignment)
    num_partitions = len(device_assignment)
  else:
    # This path is triggered for `jit(pmap)` cases.
    replicated_args = None
    in_mlir_shardings = None
    out_mlir_shardings = None
    axis_env = sharding_impls.AxisEnv(nreps, (), ())
    axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
    num_partitions = 1

  module_name = f"{api_name}_{fun_name}"

  if len(device_assignment) > 1:
    if any(effects.ordered_effects.contains(eff) for eff
           in closed_jaxpr.effects):
      raise ValueError("Ordered effects are not supported for more than 1 device.")
  ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))

  with dispatch.log_elapsed_time(
      "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
      fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
    lowering_result = mlir.lower_jaxpr_to_module(
        module_name,
        closed_jaxpr,
        ordered_effects,
        backend,
        # Optionally, override the lowering platform
        lowering_platform or backend.platform,
        axis_ctx,
        name_stack,
        donated_invars,
        replicated_args=replicated_args,
        arg_shardings=in_mlir_shardings,
        result_shardings=out_mlir_shardings,
        arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
        num_replicas=nreps,
        num_partitions=num_partitions,
        override_lowering_rules=override_lowering_rules)
  tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
  unordered_effects = list(
      effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
  return (lowering_result.module, lowering_result.keepalive,
          lowering_result.host_callbacks, unordered_effects, ordered_effects,
          nreps, tuple_args, lowering_result.shape_poly_state)


@dataclasses.dataclass(frozen=True)
class _DeviceAssignment:
  device_assignment: tuple[xc.Device, ...]

  @cached_property
  def _hash(self):
    return hash(self.device_assignment)

  def __hash__(self):
    return self._hash

  def __eq__(self, other):
    if not isinstance(other, _DeviceAssignment):
      return False
    if id(self) == id(other):
      return True
    return (self.device_assignment == other.device_assignment)

  @cached_property
  def is_fully_addressable(self):
    return len(self.device_assignment) == len(self.addressable_device_assignment)

  @cached_property
  def addressable_device_assignment(self):
    return [d for d in self.device_assignment
            if d.process_index == d.client.process_index()]


@lru_cache(maxsize=2048)
def _create_da_object(
    device_assignment: tuple[xc.Device, ...]) -> _DeviceAssignment:
  return _DeviceAssignment(device_assignment)
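
# Routing construction through the lru_cache above means repeated lowerings
# over the same device tuple typically share a single _DeviceAssignment, so
# the cached hash and addressability properties are computed once and __eq__
# can often short-circuit on identity.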


@profiler.annotate_function
def lower_sharding_computation(
    fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
    api_name: str,
    fun_name: str,
    in_shardings: Sequence[MaybeSharding],
    out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue],
    donated_invars: Sequence[bool],
    global_in_avals: Sequence[core.ShapedArray],
    *,
    keep_unused: bool,
    inline: bool,
    always_lower: bool,
    devices_from_context: Optional[Sequence[xc.Device]] = None,
    lowering_platform: Optional[str],
    override_lowering_rules: Optional[
        tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None,
) -> MeshComputation:
  """Lowers a computation to XLA. It can take arbitrary shardings as input.

  The caller of this code can pass in a singleton UNSPECIFIED because the
  number of out_avals might not be known at that time and
  lower_sharding_computation calculates the number of out_avals so it can apply
  the singleton UNSPECIFIED to all out_avals.
  """
  # 1. Trace to jaxpr and preprocess/verify it
  auto_spmd_lowering = (
      check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
      check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings])))  # type: ignore

  (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
   kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
      fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
      donated_invars, auto_spmd_lowering)
  jaxpr = closed_jaxpr.jaxpr
  in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
2023-04-10 10:15:08 -07:00
|
|
|
|
if is_unspecified(out_shardings):
|
|
|
|
|
out_shardings = (UNSPECIFIED,) * len(global_out_avals)
|
2023-04-05 14:09:46 -07:00
|
|
|
|
assert isinstance(out_shardings, tuple)
|
|
|
|
|
assert len(out_shardings) == len(global_out_avals), (
|
|
|
|
|
len(out_shardings), len(global_out_avals))
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|

  # Device assignment across all inputs, outputs and shardings inside jaxpr
  # should be the same.
  jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
  backend, device_assignment = _get_and_check_device_assignment(
      it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
               [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
               [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
                for js, source_info in jaxpr_sharding]),
      devices_from_context)

  committed = bool(
      devices_from_context or
      len(device_assignment) > 1 or
      any(not is_unspecified(i) for i in in_shardings) or
      any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
      any(not is_unspecified(o) for o in out_shardings))

  gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

  da_object = _create_da_object(tuple(device_assignment))

  if not da_object.is_fully_addressable:
    check_multihost_collective_allowlist(jaxpr)
    if inline and config.jax_spmd_mode != 'allow_all':
      raise RuntimeError(
          "Running operations on `Array`s that are not fully addressable by "
          "this process (i.e. `Array`s with data sharded across multiple "
          "devices and processes) is dangerous. It’s very important that all "
          "processes run the same cross-process computations in the same "
          "order, otherwise it can lead to hangs. "
          "If you’re not already familiar with JAX’s multi-process "
          "programming model, please read "
          "https://jax.readthedocs.io/en/latest/multi_process.html. "
          "To fix this error, run your `jitted` computation inside the "
          "`with jax.spmd_mode('allow_all'):` context manager.")

  has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
  kept_outputs = [True] * len(global_out_avals)

  # Computations that only produce constants and/or only rearrange their
  # inputs, which are often produced from partial evaluation, don't need
  # compilation, and don't need to evaluate their arguments.
  if (not always_lower and not (jaxpr.effects or has_outfeed) and
      (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
      all(is_unspecified(o) for o in out_shardings)):
    return MeshComputation(
        str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
        consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
        global_out_avals=global_out_avals, in_shardings=in_shardings,
        backend=backend, da_object=da_object,
        committed=committed, kept_var_idx=kept_var_idx, keepalive=None)

  # 2. Build up the HLO
  semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
  semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
  (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
   nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
       semantic_out_shardings, da_object, lowering_platform,
       donated_invars, name_stack, override_lowering_rules)

  # backend and device_assignment are passed through to MeshExecutable because
  # if keep_unused=False and all in_shardings are pruned, then there is no way
  # to get the device_assignment and backend. So pass them to MeshExecutable,
  # since we calculate the device_assignment and backend before in_shardings,
  # etc. are pruned.
  return MeshComputation(
      str(name_stack),
      module,
      False,
      donated_invars,
      global_in_avals=global_in_avals,
      global_out_avals=global_out_avals,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      spmd_lowering=True,
      tuple_args=tuple_args,
      auto_spmd_lowering=auto_spmd_lowering,
      unordered_effects=unordered_effects,
      ordered_effects=ordered_effects,
      host_callbacks=host_callbacks,
      keepalive=keepalive,
      kept_var_idx=kept_var_idx,
      backend=backend,
      device_assignment=da_object,
      committed=committed,
      pmap_nreps=nreps,
      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
      shape_poly_state=shape_poly_state)
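

# Example (illustrative sketch, not part of the library): this lowering path
# is what runs under the public ahead-of-time API. Assuming at least one
# device is available:
#
#   import jax
#   import jax.numpy as jnp
#   lowered = jax.jit(lambda x: x * 2).lower(jnp.arange(4))  # builds a MeshComputation
#   compiled = lowered.compile()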


def _to_logical_sharding(
    aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTO]
) -> Optional[sharding_impls.XLACompatibleSharding]:
  if is_unspecified(sharding) or is_auto(sharding):
    return None
  elif isinstance(aval, ShapedArray):
    assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
    return sharding
  elif isinstance(aval, core.AbstractToken):
    return None
  else:
    raise TypeError(aval)


@profiler.annotate_function
def lower_mesh_computation(
    fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
    api_name: str,
    fun_name: str,
    mesh: Mesh,
    in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO]],
    out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO,
                                  UnspecifiedValue]],
    donated_invars: Sequence[bool],
    spmd_lowering: bool,
    global_in_avals: Sequence[core.ShapedArray],
    tiling_method: Optional[TilingMethod],
    lowering_platform: Optional[str]) -> MeshComputation:
  assert not mesh.empty
  backend = xb.get_device_backend(mesh.devices.flat[0])
  name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

  global_axis_sizes = mesh.shape

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority,
               "Compiling %s for %s mesh with global shapes and types %s. "
               "Argument mapping: %s.",
               fun_name, tuple(global_axis_sizes.items()), global_in_avals,
               in_shardings)

  # 1. Trace to jaxpr and preprocess/verify it
  if spmd_lowering:
    manual_axes: frozenset[MeshAxisName] = frozenset()
    # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once
    # instead of vtile twice!
    if tiling_method is not None:
      if isinstance(tiling_method, TileVectorize):
        tiling_transform = vtile_by_mesh
      elif isinstance(tiling_method, TileManual):
        tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args)  # type: ignore
        manual_axes = tiling_method.manual_axes
      else:
        raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
      assert not callable(out_shardings)
      assert isinstance(fun_or_jaxpr, lu.WrappedFun)
      # This is the xmap path, where there is no `AUTO` or `UNSPECIFIED`,
      # which is why `.spec` can be accessed.
      fun_or_jaxpr = tiling_transform(
          fun_or_jaxpr, mesh, [get_array_mapping(i.spec) for i in in_shardings],  # type: ignore
          [get_array_mapping(o.spec) for o in out_shardings])  # type: ignore
    in_jaxpr_avals = global_in_avals
  else:
    assert isinstance(tiling_method, TileVectorize)
    # In the non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`,
    # which is why `.spec` can be accessed.
    in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval)  # type: ignore
                      for aval, i in safe_zip(global_in_avals, in_shardings)]
    in_jaxpr_avals = in_tiled_avals

  with core.extend_axis_env_nd(mesh.shape.items()):
    if isinstance(fun_or_jaxpr, lu.WrappedFun):
      with dispatch.log_elapsed_time(
          "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
          fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
        jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
            fun_or_jaxpr, in_jaxpr_avals)
    else:
      assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
      jaxpr = fun_or_jaxpr.jaxpr
      out_jaxpr_avals = fun_or_jaxpr.out_avals
      consts = fun_or_jaxpr.consts

  assert len(out_shardings) == len(out_jaxpr_avals)
  if spmd_lowering:
    global_out_avals = out_jaxpr_avals
  else:
    # In the non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`,
    # which is why `.spec` can be accessed.
    global_out_avals = [untile_aval_nd(global_axis_sizes, get_array_mapping(o.spec), aval)  # type: ignore
                        for aval, o in safe_zip(out_jaxpr_avals, out_shardings)]

  _sanitize_mesh_jaxpr(jaxpr)
  if mesh.is_multi_process:
    check_multihost_collective_allowlist(jaxpr)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  # 2. Build up the HLO
  tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)

  in_partitions: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
  out_partitions: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
  axis_ctx: mlir.AxisContext
  if spmd_lowering:
    in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
    out_partitions = map(_to_logical_sharding, global_out_avals, out_shardings)
    replicated_args = [False] * len(in_jaxpr_avals)
    axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
    num_replicas = 1
    num_partitions = mesh.devices.size
  else:
    replicated_args = [not get_array_mapping(i.spec) for i in in_shardings]  # type: ignore
    in_partitions = None
    out_partitions = None
    axis_env = sharding_impls.AxisEnv(
        nreps=mesh.size,
        names=tuple(global_axis_sizes.keys()),
        sizes=tuple(global_axis_sizes.values()))
    axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
    num_replicas = mesh.devices.size
    num_partitions = 1

  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module_name = f"{api_name}_{fun_name}"
  with core.extend_axis_env_nd(mesh.shape.items()):
    if any(effects.ordered_effects.contains(eff) for eff
           in closed_jaxpr.effects):
      raise ValueError("Ordered effects not supported in mesh computations.")
    unordered_effects = list(effects.ordered_effects.filter_not_in(
        closed_jaxpr.effects))
    ordered_effects = list(effects.ordered_effects.filter_in(
        closed_jaxpr.effects))
    with dispatch.log_elapsed_time(
        "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
        fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
      lowering_result = mlir.lower_jaxpr_to_module(
          module_name,
          closed_jaxpr,
          ordered_effects,
          backend,
          lowering_platform or backend.platform,
          axis_ctx,
          name_stack,
          donated_invars,
          replicated_args=replicated_args,
          arg_shardings=in_partitions,
          result_shardings=out_partitions,
          arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
          result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
          num_replicas=num_replicas,
          num_partitions=num_partitions)

  return MeshComputation(
      str(name_stack),
      lowering_result.module,
      False,
      donated_invars,
      global_in_avals=global_in_avals,
      global_out_avals=global_out_avals,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      spmd_lowering=spmd_lowering,
      tuple_args=tuple_args,
      auto_spmd_lowering=False,
      unordered_effects=unordered_effects,
      ordered_effects=ordered_effects,
      host_callbacks=lowering_result.host_callbacks,
      keepalive=lowering_result.keepalive,
      kept_var_idx=set(range(len(global_in_avals))),
      backend=backend,
      device_assignment=_create_da_object(tuple(mesh.devices.flat)),
      committed=True,
      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
      shape_poly_state=lowering_result.shape_poly_state)


class MeshComputation(stages.XlaLowering):
  _hlo: Optional[ir.Module]
  _executable: Optional[MeshExecutable]

  def __init__(self, name: str, hlo: Optional[ir.Module],
               is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
    self._name = name
    self._hlo = hlo
    self.is_trivial = is_trivial
    self._donated_invars = donated_invars
    self.compile_args = compile_args
    self._executable = None

  # -- stages.XlaLowering overrides

  def stablehlo(self) -> ir.Module:
    if self.is_trivial:
      raise ValueError("A trivial computation has no HLO")
    return self._hlo

  def compile(
      self,
      compiler_options=None,
  ) -> MeshExecutable:
    if self._executable is None or compiler_options is not None:
      if self.is_trivial:
        executable = MeshExecutable.from_trivial_jaxpr(
            **self.compile_args)
      else:
        executable = UnloadedMeshExecutable.from_hlo(
            self._name,
            self._hlo,
            **self.compile_args,
            compiler_options=compiler_options)
      if compiler_options is None:
        self._executable = executable
      return executable
    return self._executable

  def cost_analysis(self) -> dict[str, float]:
    backend = self.compile_args["backend"]
    if xb.using_pjrt_c_api(backend):
      raise NotImplementedError(
          "Lowered.cost_analysis not implemented on platform "
          f"'{backend.platform}'. Use compile().cost_analysis() for "
          "post-compilation cost estimates.")
    return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())
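

# Example (illustrative sketch, not part of the library): compile() caches the
# default-compiled executable but recompiles when compiler options are given.
#
#   import jax
#   lowered = jax.jit(lambda x: x + 1).lower(1.0)
#   exe1 = lowered.compile()   # compiled once, then cached on this object
#   exe2 = lowered.compile()   # cache hit: no recompilation
#   exe3 = lowered.compile(compiler_options={"xla_embed_ir_in_executable": True})
#   # exe3 is built fresh and intentionally not cached (see the check above)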


@lru_cache(maxsize=1024)
def _get_replicated_slices(num_addressable_devices: int, ndim: Optional[int]):
  if ndim is None:
    return ((slice(None),),) * num_addressable_devices
  else:
    return ((slice(None),) * ndim,) * num_addressable_devices
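

# Example (illustrative sketch): for a fully replicated rank-2 aval on two
# addressable devices, each device holds the whole array, so every per-device
# index is a tuple of full slices:
#
#   _get_replicated_slices(2, 2)
#   # -> ((slice(None, None, None), slice(None, None, None)),
#   #     (slice(None, None, None), slice(None, None, None)))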


def _get_input_indices(
    avals: Sequence[ShapedArray],
    shardings: Sequence[sharding_impls.XLACompatibleSharding],
    da_object: Union[_DeviceAssignment, Sequence[xc.Device]],
) -> Sequence[tuple[Optional[Index], ...]]:
  input_indices = []
  if isinstance(da_object, _DeviceAssignment):
    num_addressable_devices = len(da_object.addressable_device_assignment)
  else:
    num_addressable_devices = len(
        [d for d in da_object if d.process_index == d.client.process_index()])

  for aval, sharding in zip(avals, shardings):
    if aval is core.abstract_token:
      index = _get_replicated_slices(num_addressable_devices, None)
    else:
      if sharding.is_fully_replicated:
        index = _get_replicated_slices(num_addressable_devices, aval.ndim)
      else:
        index = tuple(
            sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
    input_indices.append(index)

  return input_indices


def get_gspmd_shardings_from_executable(
    xla_executable, device_assignment: Sequence[xc.Device],
    num_in_avals: int, num_out_avals: int
) -> tuple[Sequence[sharding_impls.XLACompatibleSharding],
           Sequence[sharding_impls.XLACompatibleSharding]]:
  from jax.experimental import pjit

  # When the device assignment only has 1 device, the SPMD partitioner will
  # not run, so the op shardings will not be set on the `hlo_module`. In that
  # case, just return SingleDeviceShardings since we know the computation is
  # running on only 1 device.
  if len(device_assignment) == 1:
    ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
    return [ss] * num_in_avals, [ss] * num_out_avals

  in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)

  in_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, i)
                      for i in in_op_shardings]
  out_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, o)
                       for o in out_op_shardings]
  # This condition happens when all the elements in the output tuple have the
  # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
  # put the sharding on ROOT instead of the tuple.
  # TODO(b/245667823): Remove this when XLA fixes this.
  if len(out_shardings_xla) == 1 and len(out_shardings_xla) < num_out_avals:
    out_shardings_xla = out_shardings_xla * num_out_avals
  assert len(out_shardings_xla) == num_out_avals, (
      len(out_shardings_xla), num_out_avals)
  return in_shardings_xla, out_shardings_xla


# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
    xla_executable, mesh: Mesh
) -> tuple[Sequence[sharding_impls.NamedSharding],
           Sequence[sharding_impls.NamedSharding]]:
  from jax.experimental import pjit

  in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
  return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec],
          [sharding_impls.NamedSharding(mesh, o) for o in out_pspec])


_orig_out_sharding_handlers = {}

_ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)


def _register_out_sharding_handler(
    sharding_cls: type[_ShardingT],
    handler: Callable[[xc.OpSharding, _ShardingT], _ShardingT],
) -> None:
  _orig_out_sharding_handlers[sharding_cls] = handler


def _gspmd_to_named_sharding(
    op_sharding: xc.OpSharding,
    self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
      op_sharding, self.mesh)[0]
  return create_mesh_pspec_sharding(
      self.mesh, parsed_pspec.get_partition_spec(), parsed_pspec)

_register_out_sharding_handler(
    sharding_impls.NamedSharding, _gspmd_to_named_sharding
)


def _gspmd_to_positional_sharding(
    op_sharding: xc.OpSharding,
    self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
  return sharding_impls._op_sharding_to_pos_sharding(
      op_sharding, self._device_assignment)

_register_out_sharding_handler(
    sharding_impls.PositionalSharding, _gspmd_to_positional_sharding
)
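

# Example (illustrative sketch; `MySharding` and its helpers are hypothetical,
# not part of JAX): a custom XLA-compatible sharding class could opt into this
# recovery mechanism by registering its own handler:
#
#   def _gspmd_to_my_sharding(op_sharding: xc.OpSharding,
#                             self: "MySharding") -> "MySharding":
#     return MySharding.from_op_sharding(op_sharding, self._device_assignment)
#
#   _register_out_sharding_handler(MySharding, _gspmd_to_my_sharding)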


def _get_out_sharding_from_orig_sharding(
    out_shardings, out_avals, orig_s, orig_aval, are_out_sharding_from_xla):
  out = []
  orig_handler = _orig_out_sharding_handlers[type(orig_s)]
  for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
                                        are_out_sharding_from_xla):
    if isinstance(o, sharding_impls.GSPMDSharding):
      try:
        # Only return the same input sharding object if the OpShardings and
        # in_aval.ndim and out_aval.ndim match. This is because a replicated
        # OpSharding doesn't encode the ndim. The devices will be the same at
        # this point because those checks happen earlier.
        if (orig_aval is not None and out_aval is not None and
            out_aval.ndim == orig_aval.ndim and
            sharding_impls.are_op_shardings_equal(
                o._hlo_sharding, orig_s._to_xla_hlo_sharding(orig_aval.ndim))):
          out.append((orig_s, False))
        else:
          out.append((orig_handler(o._hlo_sharding, orig_s), False))
      except Exception:
        out.append((o, from_xla))
    else:
      out.append((o, from_xla))
  return out


def maybe_get_orig_out_sharding(
    in_shardings, out_shardings, are_out_shardings_from_xla, in_avals,
    out_avals):
  if all(hasattr(o, '_original_sharding') for o in out_shardings):
    return ([o._original_sharding for o in out_shardings],
            (False,) * len(out_shardings))

  orig_s = None
  orig_aval = None
  for i, aval in safe_zip(in_shardings, in_avals):
    oi = getattr(i, '_original_sharding', None)
    if type(oi) in _orig_out_sharding_handlers:
      orig_s = oi
      orig_aval = aval
      break
  if orig_s is not None:
    return zip(*_get_out_sharding_from_orig_sharding(
        out_shardings, out_avals, orig_s, orig_aval, are_out_shardings_from_xla))

  return out_shardings, are_out_shardings_from_xla


@weakref_lru_cache
def _cached_compilation(computation, name, mesh, spmd_lowering,
                        tuple_args, auto_spmd_lowering,
                        _allow_propagation_to_outputs, host_callbacks, backend,
                        da, pmap_nreps, compiler_options_keys,
                        compiler_options_values):
  device_assignment = da.device_assignment if isinstance(
      da, _DeviceAssignment) else da

  # TODO(phawkins): One would normally just write:
  #   dev = np.array(device_assignment)
  # The formulation below is substantially faster if there are many devices.
  # If we were to optimize __getattr__ on xc.Device we might not need this
  # workaround.
  dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(
      np.arange(len(device_assignment))
  )
  if pmap_nreps > 1:
    num_replicas, num_partitions = pmap_nreps, 1
  elif spmd_lowering:
    num_replicas, num_partitions = 1, dev.size
  else:
    num_replicas, num_partitions = dev.size, 1

  if pmap_nreps > 1:
    # In `jit` device_assignment is set to None when num_replicas > 1. Do
    # the same thing here too.
    xla_device_assignment = None
  else:
    xla_device_assignment = dev.reshape((num_replicas, num_partitions))

  if compiler_options_keys is None:
    compiler_options = None
  else:
    compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))

  compile_options = xb.get_compile_options(
      num_replicas=num_replicas,
      num_partitions=num_partitions,
      device_assignment=xla_device_assignment,
      use_spmd_partitioning=spmd_lowering,
      use_auto_spmd_partitioning=auto_spmd_lowering,
      env_options_overrides=compiler_options,
  )

  opts = compile_options.executable_build_options
  if auto_spmd_lowering:
    assert mesh is not None
    opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
    opts.auto_spmd_partitioning_mesh_ids = (
        sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
        .reshape(-1))
  compile_options.parameter_is_tupled_arguments = tuple_args
  opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)

  if hasattr(backend, "compile_replicated"):
    return None, compile_options

  with dispatch.log_elapsed_time(
      "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
      fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
    xla_executable = dispatch.compile_or_get_cached(
        backend, computation, dev, compile_options, host_callbacks)
  return xla_executable, compile_options


@dataclasses.dataclass
class UnloadedMeshExecutable:
  xla_executable: Any
  device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]]
  backend: xb.XlaBackend
  input_avals: Sequence[ShapedArray]
  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  output_avals: Sequence[ShapedArray]
  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  committed: bool
  are_out_shardings_from_xla: Sequence[bool]
  name: str
  unordered_effects: list[core.Effect]
  ordered_effects: list[core.Effect]
  keepalive: Sequence[Any]
  host_callbacks: Sequence[Any]
  kept_var_idx: set[int]
  auto_spmd_lowering: bool
  jaxpr_debug_info: Optional[core.JaxprDebugInfo]

  def build_unsafe_call(self):
    input_indices = _get_input_indices(self.input_avals, self.input_shardings,
                                       self.device_assignment)
    handle_args = InputsHandler(self.xla_executable.local_devices(),
                                self.input_shardings, input_indices)
    handle_outs = global_avals_to_results_handler(
        self.output_avals, self.output_shardings, self.committed,
        self.are_out_shardings_from_xla)  # type: ignore  # arg-type

    unsafe_call = ExecuteReplicated(  # type: ignore  # assignment
        self.xla_executable, self.name, self.backend, handle_args,
        handle_outs, self.unordered_effects, self.ordered_effects,
        self.keepalive, bool(self.host_callbacks), self.kept_var_idx)
    return unsafe_call

  def load(self) -> MeshExecutable:
    return MeshExecutable(self.xla_executable, self.build_unsafe_call,
                          self.input_avals,
                          self.input_shardings, self.output_shardings,
                          self.auto_spmd_lowering, self.kept_var_idx,
                          self.jaxpr_debug_info, self)

  # May return a MeshExecutable in the compile_replicated case.
  @staticmethod
  def from_hlo(name: str,
               hlo: ir.Module,
               global_in_avals: Sequence[ShapedArray],
               global_out_avals: Sequence[ShapedArray],
               in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO]],
               out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO,
                                             UnspecifiedValue]],
               spmd_lowering: bool,
               tuple_args: bool,
               auto_spmd_lowering: bool,
               unordered_effects: list[core.Effect],
               ordered_effects: list[core.Effect],
               host_callbacks: list[Any],
               keepalive: Any,
               kept_var_idx: set[int],
               backend: xb.XlaBackend,
               device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]],
               committed: bool,
               pmap_nreps: int = 1,
               jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None,
               shape_poly_state: Optional[mlir.ShapePolyLoweringState] = None,
               compiler_options=None
               ) -> MeshExecutable:
    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
      hlo = mlir.refine_polymorphic_shapes(hlo)
    compiler_options_keys = tuple(
        compiler_options.keys()) if compiler_options is not None else None
    compiler_options_values = tuple(
        compiler_options.values()) if compiler_options is not None else None
    da = device_assignment if isinstance(
        device_assignment, _DeviceAssignment) else tuple(device_assignment)
    del device_assignment
    allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)

    mesh = None
    if auto_spmd_lowering:
      for i in it.chain.from_iterable([in_shardings, out_shardings]):
        if is_auto(i):
          mesh = i.mesh  # type: ignore
          break

    xla_executable, compile_options = _cached_compilation(
        hlo, name, mesh, spmd_lowering,
        tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
        tuple(host_callbacks), backend, da, pmap_nreps,
        compiler_options_keys, compiler_options_values)

    if hasattr(backend, "compile_replicated"):
      semantics_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
      semantics_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
      return _compile_replicated_mesh_executable_from_hlo(
          hlo, name, tuple(global_in_avals), tuple(global_out_avals),
          semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
          compile_options, tuple(host_callbacks), bool(unordered_effects),
          tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
          pmap_nreps, jaxpr_debug_info)

    if auto_spmd_lowering:
      assert mesh is not None
      in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
          xla_executable, mesh)
      in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i)  # type: ignore
                      for x, i in safe_zip(in_shardings_xla, in_shardings)]
      out_shardings_tuple = [
          (x, True) if is_auto(o) else (o, False)
          for x, o in safe_zip(out_shardings_xla, out_shardings)
      ]
      out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
    elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
          and pmap_nreps == 1):
      assert mesh is None
      device_assignment = da.device_assignment if isinstance(  # type: ignore
          da, _DeviceAssignment) else da
      _, out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
          xla_executable, device_assignment,  # type: ignore
          len(global_in_avals), len(global_out_avals))
      orig_out_shardings = out_shardings
      out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
      for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
                                        global_out_avals):
        if is_unspecified(orig):
          out_shardings.append(xla_s)
          are_out_shardings_from_xla.append(True)
        else:
          if not op_shardings.are_op_shardings_equal(
              xla_s._to_xla_hlo_sharding(aval.ndim),  # type: ignore
              orig._to_xla_hlo_sharding(aval.ndim)):  # type: ignore
            raise AssertionError(
                f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
                "(User sharding)")
          out_shardings.append(orig)
          are_out_shardings_from_xla.append(False)
    else:
      are_out_shardings_from_xla = (False,) * len(global_out_avals)

    if pmap_nreps > 1:
      in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
          xla_executable.local_devices(), len(in_shardings), len(out_shardings))

    out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
        in_shardings, out_shardings, are_out_shardings_from_xla,
        global_in_avals, global_out_avals)

    return UnloadedMeshExecutable(
        xla_executable=xla_executable,
        device_assignment=da,  # type: ignore
        backend=backend,
        input_avals=global_in_avals,
        input_shardings=in_shardings,  # type: ignore
        output_avals=global_out_avals,
        output_shardings=out_shardings,  # type: ignore  # arg-type
        committed=committed,
        are_out_shardings_from_xla=are_out_shardings_from_xla,
        name=name,
        unordered_effects=unordered_effects,
        ordered_effects=ordered_effects,
        keepalive=keepalive,
        host_callbacks=host_callbacks,
        kept_var_idx=kept_var_idx,
        auto_spmd_lowering=auto_spmd_lowering,
        jaxpr_debug_info=jaxpr_debug_info).load()
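

# Note: load() above does not build the runtime call eagerly; MeshExecutable
# (below) invokes build_unsafe_call lazily, the first time its unsafe_call
# property is accessed.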


class MeshExecutableFastpathData(NamedTuple):
  xla_executable: xc.LoadedExecutable
  out_pytree_def: Any
  in_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  out_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  out_avals: Sequence[ShapedArray]
  out_committed: Sequence[bool]
  kept_var_bitvec: Iterable[bool]


class MeshExecutable(stages.XlaExecutable):
  __slots__ = [
      "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
      "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
      "_jaxpr_debug_info", "_unloaded_executable",
  ]

  def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
               out_shardings, auto_spmd_lowering, kept_var_idx,
               jaxpr_debug_info=None, unloaded_executable=None):
    self.xla_executable = xla_executable
    self.build_unsafe_call = build_unsafe_call
    # in_avals is a list of global and local avals. An aval is global if the
    # input is a GDA or jax.Array, and local otherwise.
    self.in_avals = in_avals
    self._unsafe_call = None
    self._in_shardings = in_shardings
    self._out_shardings = out_shardings
    self._auto_spmd_lowering = auto_spmd_lowering
    self._kept_var_idx = kept_var_idx
    self._jaxpr_debug_info = jaxpr_debug_info
    self._unloaded_executable = unloaded_executable

  @property
  def unsafe_call(self) -> Callable[..., Any]:
    if self._unsafe_call is None:
      self._unsafe_call = self.build_unsafe_call()
    return self._unsafe_call

  @staticmethod
  def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
                         in_shardings, backend, da_object,
                         committed, kept_var_idx, keepalive) -> MeshExecutable:
    assert keepalive is None
    if hasattr(backend, "compile_replicated"):
      return _compile_replicated_mesh_executable_from_trivial_jaxpr(
          jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
          backend, da_object, committed, kept_var_idx, 1)

    out_shardings = _out_shardings_for_trivial(
        jaxpr, consts, in_shardings, da_object.device_assignment)
    indices = _get_input_indices(global_out_avals, out_shardings, da_object)
    local_device_assignment = da_object.addressable_device_assignment
    handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
    handle_outs = global_avals_to_results_handler(
        global_out_avals, out_shardings, committed,
        [False] * len(global_out_avals))
    unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
                          handle_outs, kept_var_idx)
    return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
                          in_shardings, out_shardings, False, kept_var_idx,
                          None)

  # -- stages.XlaExecutable overrides

  def xla_extension_executable(self):
    return self.xla_executable

  def call(self, *args):
    kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
    arg_avals = map(xla.abstractify, kept_args)
    ref_avals = self.in_avals
    check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
    # Check the GDA sharding and the input sharding.
    check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
                                          self._jaxpr_debug_info)
    return self.unsafe_call(*args)  # pylint: disable=not-callable

  def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
    return self._in_shardings

  def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
    return self._out_shardings

  def create_cpp_call(self, no_kwargs, in_tree, out_tree):
    if not (isinstance(self.unsafe_call, ExecuteReplicated) and
            not self.unsafe_call.has_unordered_effects and
            not self.unsafe_call.has_host_callbacks):
      return None

    def aot_cache_miss(*args, **kwargs):
      params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
      outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
      use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))

      if use_fastpath:
        out_avals = [o.aval for o in out_flat]
        out_committed = [o._committed for o in out_flat]
        kept_var_bitvec = [i in self._kept_var_idx
                           for i in range(len(args_flat))]
        fastpath_data = MeshExecutableFastpathData(
            self.xla_executable, out_tree, self._in_shardings,
            self._out_shardings, out_avals, out_committed, kept_var_bitvec)
      else:
        fastpath_data = None
      return outs, fastpath_data

    return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [])  # type: ignore


def check_arg_avals_for_call(ref_avals, arg_avals,
                             jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None):
  if len(ref_avals) != len(arg_avals):
    raise TypeError(
        f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")

  if jaxpr_debug_info is not None:
    arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
  else:
    num_args = len(ref_avals)
    arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]

  errors = []
  for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
    if not core.typematch(ref_aval, arg_aval):
      errors.append(
          f"Argument {name} compiled with {ref_aval.str_short()} and called "
          f"with {arg_aval.str_short()}")
  if errors:
    max_num_errors = 5
    str_errors = "\n".join(errors[:max_num_errors])
    if len(errors) >= max_num_errors:
      num_mismatch_str = f"The first {max_num_errors} of {len(errors)}"
    else:
      num_mismatch_str = "The"
    raise TypeError(
        "Argument types differ from the types for which this computation was "
        f"compiled. {num_mismatch_str} mismatches are:\n{str_errors}")
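

# Example (illustrative sketch, not part of the library): the check above is
# what raises when an ahead-of-time compiled function is called with avals
# that differ from the ones it was compiled for:
#
#   import jax
#   import jax.numpy as jnp
#   compiled = jax.jit(lambda x: x + 1).lower(jnp.zeros((2,), jnp.float32)).compile()
#   compiled(jnp.zeros((3,), jnp.float32))
#   # TypeError: ... Argument 'x' compiled with float32[2] and called with float32[3]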


def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
  # Create replicated shardings for the jit(pmap) path with local devices
  # because multihost jit(pmap) is not allowed.
  gs = sharding_impls.GSPMDSharding.get_replicated(local_devices)
  in_shardings = [gs] * num_in_shardings
  out_shardings = [gs] * num_out_shardings
  # jit(pmap) will generate Arrays with multi-device sharding.
  # It is unsupported for these shardings to be uncommitted, so force
  # the outputs to be committed.
  committed = True
  return in_shardings, out_shardings, committed, tuple(local_devices)


def _out_shardings_for_trivial(
    jaxpr: core.Jaxpr, consts: Sequence[Any],
    in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
    device_assignment: Sequence[xc.Device],
) -> list[sharding_impls.XLACompatibleSharding]:
  # For each jaxpr output, compute a Sharding by:
  # * if the output is a forwarded input, get the corresponding in_sharding;
  # * if the output is a constant Array, get its .sharding attribute;
  # * otherwise, the output is a literal or numpy.ndarray constant, so give it
  #   a replicated sharding
  from jax._src import array

  if len(device_assignment) > 1:
    rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
    in_shardings = tuple(
        i._original_sharding if hasattr(i, '_original_sharding') else i
        for i in in_shardings)
  else:
    dev, = device_assignment
    rep = sharding_impls.SingleDeviceSharding(dev)
    in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)

  shardings: dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
  for constvar, constval in zip(jaxpr.constvars, consts):
    if isinstance(constval, array.ArrayImpl):
      shardings[constvar] = constval.sharding
  map(shardings.setdefault, jaxpr.invars, in_shardings)
  return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
          for x in jaxpr.outvars]


def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
  env: dict[core.Var, Any] = {}
  pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
  map(env.setdefault, jaxpr.invars, pruned_args)
  map(env.setdefault, jaxpr.constvars, consts)
  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
          for v in jaxpr.outvars]
  return out_handler(in_handler(outs))
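

# Example (illustrative sketch, not part of the library): a jitted identity
# has no equations, so under the conditions checked in
# lower_sharding_computation it takes this trivial path and never builds an
# XLA executable:
#
#   import jax
#   y = jax.jit(lambda x: x)(1.0)  # outputs are just forwarded inputs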


@weakref_lru_cache
def _compile_replicated_mesh_executable_from_hlo(
    computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
    semantics_out_shardings, auto_spmd_lowering, compile_options,
    host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
    backend, da, committed, pmap_nreps, jaxpr_debug_info):
  assert not auto_spmd_lowering
  in_shardings = semantics_in_shardings.shardings
  out_shardings = semantics_out_shardings.shardings

  input_indices = _get_input_indices(global_in_avals, in_shardings, da)  # type: ignore
  if pmap_nreps > 1:
    # For a jit wrapping a pmap, replicate each input index to match the
    # devices of the replicated jit computation.
    input_indices = [index * pmap_nreps for index in input_indices]
  kept_var_idx = set(kept_var_idx)
  # Will compute out_handler with executable information.
  unsafe_call = backend.compile_replicated(
      is_trivial=False, name=name, computation=computation,
      compile_options=compile_options, host_callbacks=host_callbacks,
      has_unordered_effects=has_unordered_effects,
      ordered_effects=ordered_effects, in_avals=global_in_avals,
      in_indices=input_indices, in_shardings=in_shardings,
      kept_var_idx=kept_var_idx,
      out_avals=global_out_avals, out_shardings=out_shardings,
      committed=committed, pmap_nreps=pmap_nreps)
  xla_executable = None
  return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                        in_shardings, out_shardings, auto_spmd_lowering,
                        kept_var_idx, jaxpr_debug_info, None)


def _compile_replicated_mesh_executable_from_trivial_jaxpr(
    jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
    da_object, committed, kept_var_idx, pmap_nreps):
  out_shardings = _out_shardings_for_trivial(
      jaxpr, consts, in_shardings, da_object.device_assignment)

  input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
  handle_outs = global_avals_to_results_handler(
      global_out_avals, out_shardings, committed,
      [False] * len(global_out_avals))
  # Use the standard out_handler.
  unsafe_call = backend.compile_replicated(
      is_trivial=True, jaxpr=jaxpr, consts=consts,
      device_assignment=da_object.device_assignment, in_avals=global_in_avals,
      in_indices=input_indices, in_shardings=in_shardings,
      kept_var_idx=kept_var_idx, out_handler=handle_outs,
      out_shardings=out_shardings, pmap_nreps=pmap_nreps)
  return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
                        in_shardings, out_shardings, False, kept_var_idx,
                        None)


@lru_cache()
def create_mesh_pspec_sharding(
    mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None
) -> sharding_impls.NamedSharding:
  if pspec is None:
    pspec, parsed_pspec = PartitionSpec(), None
  return sharding_impls.NamedSharding(mesh, pspec, parsed_pspec)
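

# Example (illustrative sketch; assumes four devices so a 2x2 mesh can be
# built):
#
#   import numpy as np
#   import jax
#   from jax.sharding import Mesh, PartitionSpec
#   mesh = Mesh(np.array(jax.devices()).reshape(2, 2), ('x', 'y'))
#   s = create_mesh_pspec_sharding(mesh, PartitionSpec('x', None))
#   # s is a NamedSharding that shards dim 0 across mesh axis 'x' and
#   # replicates dim 1.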


def check_device_backend_on_shardings(shardings) -> bool:
  for i in shardings:
    if is_unspecified(i) or is_auto(i):
      continue
    if hasattr(i, '_original_sharding') and getattr(
        i._original_sharding, '_device_backend', False):
      return True
  return False
def check_gda_or_array_xla_sharding_match(
|
2023-04-19 12:35:15 -07:00
|
|
|
|
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
|
|
|
|
jaxpr_debug_info: Optional[core.JaxprDebugInfo]) -> None:
|
2023-02-06 14:28:36 -08:00
|
|
|
|
from jax._src.array import ArrayImpl
|
2023-04-19 12:35:15 -07:00
|
|
|
|
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
|
|
|
|
|
jaxpr_debug_info.arg_names)
|
|
|
|
|
errors = []
|
|
|
|
|
num_errors = 5
|
|
|
|
|
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
|
2023-03-15 12:59:33 -07:00
|
|
|
|
if not isinstance(arg, ArrayImpl):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# No need to cache this check since MeshExecutable has a C++ fast path
|
|
|
|
|
# for AOT compiled call.
|
2023-02-07 11:16:01 -08:00
|
|
|
|
if (not check_device_backend_on_shardings([xs]) and
|
2023-03-15 12:59:33 -07:00
|
|
|
|
arg._committed and
|
2023-04-06 08:31:47 -07:00
|
|
|
|
not op_shardings.are_op_shardings_equal(
|
2023-06-05 13:40:59 -07:00
|
|
|
|
arg.sharding._to_xla_hlo_sharding(arg.ndim),
|
|
|
|
|
xs._to_xla_hlo_sharding(arg.ndim))):
|
2023-04-19 12:35:15 -07:00
|
|
|
|
errors.append(
|
|
|
|
|
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
|
|
|
|
|
f"arg {name} with shape: {arg.aval.str_short()}")
|
|
|
|
|
|
|
|
|
|
if errors:
|
2023-04-19 15:08:21 -07:00
|
|
|
|
str_errors = '\n'.join(errors[:num_errors])
|
|
|
|
|
num_mismatch_str = (
|
|
|
|
|
f'the {len(errors)} mismatches' if len(errors) < num_errors else
|
|
|
|
|
f"{num_errors} mismatches out of {len(errors)}")
|
2023-04-19 12:35:15 -07:00
|
|
|
|
raise ValueError(
|
|
|
|
|
"Array(s) sharding does not match the input(s) sharding. "
|
2023-04-19 15:08:21 -07:00
|
|
|
|
f"Here are {num_mismatch_str}:\n{str_errors}")


def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
  parsed_pspec, _, _ = sharding_impls.prepare_axis_resources(
      pspec, "pspec to array_mapping")
  return _get_array_mapping(parsed_pspec)
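
# Sketch: get_array_mapping(PartitionSpec('x', None, ('y', 'z'))) yields the
# mapping {'x': 0, 'y': 2, 'z': 2}, i.e. each named mesh axis points at the
# array dimension it partitions; dimensions mapped to None stay unsharded.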


_forbidden_primitives = {
  'xla_pmap': 'pmap',
}
def _sanitize_mesh_jaxpr(jaxpr):
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  for eqn in jaxpr.eqns:
    if eqn.primitive.name in _forbidden_primitives:
      raise RuntimeError(f"Nesting {_forbidden_primitives[eqn.primitive.name]} "
                         "inside xmaps is not supported!")
    core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)
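
# E.g. an xmapped function that internally calls jax.pmap stages out an
# 'xla_pmap' equation somewhere in its (possibly nested) jaxpr, which the
# traversal above rejects.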


custom_resource_typing_rules: dict[core.Primitive, Callable] = {}

def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  def _check_aval(aval, what_thunk):
    if not hasattr(aval, 'named_shape'):
      return
    resource_to_axis = {}
    for axis in aval.named_shape:
      if axis_resources:
        for resource in axis_resources[axis]:
          if resource in resource_to_axis:
            other_axis = resource_to_axis[resource]
            axis, other_axis = sorted([str(axis), str(other_axis)])
            raise JAXTypeError(
                f"Axes `{axis}` and `{other_axis}` are both mapped to the "
                f"resource `{resource}`, but they coincide in the named_shape "
                f"of {what_thunk()}")
          resource_to_axis[resource] = axis

  what_thunk = lambda: (f"an input to {what_jaxpr_thunk()}")
  for v in jaxpr.constvars:
    _check_aval(v.aval, what_thunk)
  for v in jaxpr.invars:
    _check_aval(v.aval, what_thunk)
  # NOTE: These thunks close over `eqn` before it is bound; they are only
  # ever called from inside the loop below, where `eqn` refers to the
  # equation currently being checked.
  what_thunk = lambda: (f"a value returned from a primitive {eqn.primitive} created "
                        f"at {source_info_util.summarize(eqn.source_info)}")
  rec_what_jaxpr_thunk = lambda: (f"a primitive {eqn.primitive} created at "
                                  f"{source_info_util.summarize(eqn.source_info)}")
  for eqn in jaxpr.eqns:
    typing_rule = custom_resource_typing_rules.get(eqn.primitive, None)
    if typing_rule:
      typing_rule([v.aval for v in eqn.invars], eqn.params, eqn.source_info,
                  resource_env, axis_resources)
    else:
      core.traverse_jaxpr_params(partial(resource_typecheck,
                                         resource_env=resource_env,
                                         axis_resources=axis_resources,
                                         what_jaxpr_thunk=rec_what_jaxpr_thunk),
                                 eqn.params)
    for v in eqn.outvars:
      _check_aval(v.aval, what_thunk)
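
# Failure sketch (hypothetical axes): if an aval's named_shape contains axes
# 'i' and 'j' and axis_resources maps both of them to the mesh resource 'x',
# the value would have to be partitioned over 'x' in two ways at once, so
# _check_aval raises a JAXTypeError naming both axes.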


def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
  # NOTE: This takes in the non-sharded avals!
  def mk_sharding_spec(aval, aval_axes):
    if aval is core.abstract_token:
      assert not aval_axes
      return ShardingSpec([], [Replicated(axis_size) for axis_size in axis_sizes.values()])
    aval_shape = list(aval.shape)
    # NOTE: sorted is stable, which is important when multiple resources
    # map to the same axis.
    for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
      if not allow_uneven_axes:
        if aval_shape[axis] % axis_sizes[name] != 0:
          raise ValueError(
              f'The aval shape on dimension {axis} is {aval_shape[axis]} and '
              f'the size of axis {name} is {axis_sizes[name]}. The aval shape '
              'must be divisible by the axis size, but the remainder is '
              f'{aval_shape[axis] % axis_sizes[name]}')
      aval_shape[axis] //= axis_sizes[name]
    return sharding_specs.make_sharding_spec(
        axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
  return mk_sharding_spec
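
# Worked example (illustrative): with axis_sizes = {'x': 2, 'y': 4} and
# axis_names = ('x', 'y'), calling the returned mk_sharding_spec on a
# f32[8, 12] aval with aval_axes = {'x': 0, 'y': 1} checks 8 % 2 == 0 and
# 12 % 4 == 0, then builds a spec whose per-shard shape is (4, 3).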


@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
  with core.extend_axis_env(*args, **kwargs):
    yield


def device_put(x, devices: Sequence[xc.Device],
               replicate: bool = False) -> list[xc.ArrayImpl]:
  """Call device_put on a sequence of devices and return a flat sequence of buffers."""
  if replicate:
    return [jax.device_put(x, device) for device in devices]
  else:
    return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]
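
# Illustrative behavior, assuming two local devices d0 and d1:
#
#   device_put(x, [d0, d1], replicate=True)   # -> [x on d0, x on d1]
#   device_put([a, b], [d0, d1])              # -> [a on d0, b on d1]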


# TODO(phawkins): fix external users not to use these functions.
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
  return sharding_specs.create_pmap_sharding_spec(
      aval.shape, sharded_dim, sharded_dim_size)


def _pmap_sharding_spec(nrep, axis_size, npart, parts,
                        sharded_aval, map_axis: Optional[int]) -> ShardingSpec:
  assert npart == 1, npart
  assert parts is None, parts
  return sharding_specs.pmap_sharding_spec(
      nrep, axis_size, sharded_aval.shape, map_axis)
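
# Illustrative sketch: for an aval of shape (8, 4),
# _create_pmap_sharding_spec(aval) describes the canonical pmap layout, with
# the leading dimension split into 8 shards (one per mapped replica) and the
# trailing dimension left unsharded.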