# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of pmap and related functionality."""

from __future__ import annotations

import enum
from contextlib import contextmanager
from collections import namedtuple
from collections.abc import Sequence, Iterable, Iterator
import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import threading
from typing import Any, Callable, NamedTuple, TypeVar, Union, cast
import warnings

import numpy as np

import jax
from jax.errors import JAXTypeError

from jax._src import api_util
from jax._src import compiler
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.core import DShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
    ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
    UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
    is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
    SingleDeviceSharding, GSPMDSharding)
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
                           tuple_update, tuple_delete, distributed_debug_log,
                           unzip2, HashableFunction, weakref_lru_cache)
from jax._src.state.types import AbstractRef, RefEffect

# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
  pass


xe = xc._xla

unsafe_map, map = map, safe_map  # type: ignore

logger = logging.getLogger(__name__)

Index = Union[int, slice, tuple[Union[int, slice], ...]]

NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
Unstacked = sharding_specs.Unstacked

ShardedAxis = sharding_specs.ShardedAxis
Replicated = sharding_specs.Replicated

AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = mesh_lib.Mesh
MeshAxisName = sharding_impls.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec


### util

def identity(x): return x


def shard_arg(arg, sharding, canonicalize=True):
  if canonicalize:
    arg = xla.canonicalize_dtype(arg)
  return shard_arg_handlers[type(arg)](arg, sharding)


@profiler.annotate_function
def shard_args(
    shardings: Sequence[sharding_impls.XLACompatibleSharding], args,
) -> Sequence[jax.Array]:
  return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]


shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
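# Each handler takes (arg, sharding) and returns a jax.Array holding `arg`
# placed on the sharding's addressable devices; handlers for tokens, numpy
# arrays, DArrays, and mutable arrays are registered below.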


@lru_cache(maxsize=1024)
def get_addressable_devices_for_shard_arg(
    s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]:
  return s._addressable_device_assignment


@lru_cache(maxsize=1024)
def _get_replicated_slices(num_addressable_devices: int):
  return ((slice(None),),) * num_addressable_devices


def _shard_token(x, sharding):
  devices = get_addressable_devices_for_shard_arg(sharding)
  indices = _get_replicated_slices(len(devices))
  zeros = np.zeros((), dtype=np.dtype(np.bool_))
  aval = api_util.shaped_abstractify(zeros)
  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
shard_arg_handlers[core.Token] = _shard_token


def _masked_array_error(x, sharding):
  raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                   "Use arr.filled() to convert the value to a standard numpy array.")
shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error


def _shard_array(x, sharding):
  devices = get_addressable_devices_for_shard_arg(sharding)
  if x.dtype == dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  aval = api_util.shaped_abstractify(x)
  if sharding.is_fully_replicated:
    shards = [x] * len(devices)
  else:
    indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
    shards = [x[i] for i in indices]
  return batched_device_put(aval, sharding, shards, devices)
for _t in array_types:
  shard_arg_handlers[_t] = _shard_array


def _shard_darray(x, sharding):
  return shard_arg(x._data, sharding)
shard_arg_handlers[core.DArray] = _shard_darray


def _shard_mutable_array(x, sharding):
  return shard_arg(x._buf, sharding)
shard_arg_handlers[core.MutableArray] = _shard_mutable_array


def batched_device_put(aval: core.ShapedArray,
                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
                       devices: Sequence[jax.Device], committed: bool = True):
  from jax._src import array

  # Fast path: if every input is already a single-device Array living on the
  # corresponding target device, reuse the existing buffers without copying.
  bufs = [x for x, d in safe_zip(xs, devices)
          if (isinstance(x, array.ArrayImpl) and
              dispatch.is_single_device_sharding(x.sharding) and
              x.devices() == {d})]
  if len(bufs) == len(xs):
    return array.ArrayImpl(
        aval, sharding, bufs, committed=committed, _skip_checks=True)
  # Otherwise, fall back to the C++ batched device_put.
  return xc.batched_device_put(aval, sharding, xs, list(devices), committed)  # type: ignore


def _shard_aval(size, axis: int, aval):
  try:
    return _shard_aval_handlers[type(aval)](size, axis, aval)
  except KeyError as err:
    raise TypeError(f"No _shard_aval handler for type: {type(aval)}") from err
_shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}


def _shard_abstract_array(size, axis: int, x):
  try:
    if x.shape[axis] != size:
      raise ValueError(f"Axis size {size} does not match dimension {axis} of "
                       f"shape {x.shape}")
  except IndexError:
    raise ValueError(f"Cannot split a {x.ndim}D value along axis {axis}") from None
  if config.pmap_no_rank_reduction.value:
    return x.update(shape=tuple_update(x.shape, axis, 1))
  else:
    return x.update(shape=tuple_delete(x.shape, axis))
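# For example, sharding a ShapedArray of shape (8, 100) along axis 0 with
# size 8 yields shape (1, 100) when jax_pmap_no_rank_reduction is enabled (a
# "chunked" axis) and shape (100,) otherwise (an "unstacked" axis).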
_shard_aval_handlers[ShapedArray] = _shard_abstract_array


def local_aval_to_result_handler(
    aval: core.AbstractValue,
    sharding: sharding_impls.XLACompatibleSharding,
    indices: tuple[Index, ...] | None,
) -> Callable[[list[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  Args:
    aval: The local output AbstractValue.
    sharding: Indicates how the output is sharded across devices, or None
      for non-array avals.
    indices: The pre-computed result of spec_to_indices, or None for non-array
      avals.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user, e.g. an Array.
  """
  try:
    return local_result_handlers[type(aval)](aval, sharding, indices)
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {type(aval)}") from err


PxlaResultHandler = Callable[..., Callable[[Any], Any]]
local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}


def global_aval_to_result_handler(
    aval: core.AbstractValue, out_sharding, committed: bool
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  Args:
    aval: The global output AbstractValue.
    out_sharding: The sharding of the output.
    committed: Whether the resulting array is committed to the devices of
      `out_sharding`.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user, e.g. an Array.
  """
  try:
    return global_result_handlers[type(aval)](aval, out_sharding, committed)
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {type(aval)}") from err


global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}

### lazy device-memory persistence and result handling

### the xla_pmap primitive and its rules are comparable to xla_call in xla.py


def xla_pmap_impl_lazy(
    fun: lu.WrappedFun,
    *args,
    backend: str | None,
    axis_name: core.AxisName,
    axis_size: int,
    global_axis_size: int,
    devices: Sequence[Any] | None,
    name: str,
    in_axes: Sequence[int | None],
    out_axes_thunk: Callable[[], Sequence[int | None]],
    donated_invars: Sequence[bool],
    is_explicit_global_axis_size: bool,
) -> Callable:
  if (config.disable_jit.value and config.eager_pmap.value and
      not is_explicit_global_axis_size and not any(d for d in donated_invars)):
    def _emap_apply_fn(*args):
      return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
                        axis_size=axis_size, global_axis_size=global_axis_size,
                        devices=devices, name=name, in_axes=in_axes,
                        out_axes_thunk=out_axes_thunk,
                        donated_invars=donated_invars,
                        is_explicit_global_axis_size=is_explicit_global_axis_size)
    return _emap_apply_fn
  abstract_args = unsafe_map(xla.abstractify, args)
  compiled_fun, fingerprint = parallel_callable(
      fun, backend, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, *abstract_args)

  # For performance, don't re-abstractify args unless debug logging is enabled.
  if config.distributed_debug.value:
    distributed_debug_log(("Running pmapped function", name),
                          ("python function", fun.f),
                          ("devices", devices),
                          ("abstract args", map(xla.abstractify, args)),
                          ("fingerprint", fingerprint))
  return compiled_fun


def xla_pmap_impl(fun: lu.WrappedFun, *args, **params):
  compiled_fun = xla_pmap_impl_lazy(fun, *args, **params)
  return compiled_fun(*args)


class EmapInfo(NamedTuple):
  backend: str | None
  devices: Sequence[Any] | None


def _emap_impl(fun: lu.WrappedFun, *args,
               backend: str | None,
               axis_name: core.AxisName,
               axis_size: int,
               global_axis_size: int,
               devices: Sequence[Any] | None,
               name: str,
               in_axes: Sequence[int | None],
               out_axes_thunk: Callable[[], Sequence[int | None]],
               donated_invars: Sequence[bool],
               is_explicit_global_axis_size: bool,
               ):
  from jax._src import array
  # TODO(sharadmv,mattjj): implement these cases
  if any(d for d in donated_invars):
    raise NotImplementedError("Buffer donation not supported in eager pmap.")
  if is_explicit_global_axis_size:
    raise NotImplementedError("Non-default global_axis_size not supported in "
                              "eager pmap.")

  emap_info = EmapInfo(backend, devices)
  shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
  with core.new_base_main(MapTrace, emap_info=emap_info) as main:
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
      t = main.with_cur_sublevel()
      tracers = [MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
      ans = fun.call_wrapped(*tracers)
      out_tracers = map(t.full_raise, ans)
      outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
    del main
  out_axes = out_axes_thunk()

  platform = xb.get_backend(backend).platform
  donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else ()
  new_outvals = []
  for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
    with jax.disable_jit(False):
      donate_argnums_ = donate_argnums
      if isinstance(outval, array.ArrayImpl):
        # We don't want to donate if it's already sharded.
        donate_argnums_ = ()
      out = jax.pmap(
          lambda _, x: x,
          in_axes=(0, out_axis_src.get(axis_name)),
          out_axes=out_axis,
          devices=(None if devices is None else list(devices)),
          backend=backend,
          donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
      new_outvals.append(out)
  return new_outvals


def _map_schedule(idx: tuple[int | None, ...]) -> tuple[int | None, ...]:
  # In order to do a multi-map (a simultaneous map over several axes), we will
  # nest several maps. Each time we do a map, we "remove" an input axis so we
  # need to update the remaining map axes. For example, if we are to map over
  # the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
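  # Concretely (illustrative):
  #   _map_schedule((0, 3, 4))  == (0, 2, 2)
  #   _map_schedule((None, 1))  == (None, 1)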
  return tuple(None if i is None else
               i - sum(j is not None and j < i for j in idx[:l])
               for l, i in enumerate(idx))


# We're often creating `f`s on the fly and we try to carefully make them have
# the right __hash__ and __eq__. However, despite our attempts pmap's caching
# still ends up not working, because it has a separate cache per
# _function object_. Adding this annotation here lets us reuse the same pmap
# callable for all equivalent primitive pmaps.
@lru_cache
def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName],
                all_axes: list[tuple[int | None, ...]]
                ) -> tuple[Callable, dict[core.AxisName, int]]:
  used_names = []
  for i, name in reversed(list(enumerate(names))):
    in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
    if any(in_axis is not None for in_axis in in_axes):
      f = jax.pmap(
          f,
          in_axes=in_axes,
          axis_name=name,
          out_axes=0,
          backend=info.backend,
          devices=(None if info.devices is None else list(info.devices)))
      used_names.append(name)
  out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
  return f, out_shard_axes


FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"])
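# FakePrimitive mimics just enough of the Primitive interface (multiple_results
# and bind) for MapTrace.process_map and process_axis_index to reuse
# process_primitive below.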


class MapTrace(core.Trace):

  def __init__(self, *args, emap_info):
    super().__init__(*args)
    self.emap_info = emap_info

  def pure(self, val):
    return MapTracer(self, val, {})

  def sublift(self, tracer):
    return MapTracer(self, tracer.val, tracer.shard_axes)

  def process_primitive(self, primitive, tracers, params):
    info = self.main.payload["emap_info"]
    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
    names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env
                  if f.main_trace is self.main)
    all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
    f = HashableFunction(lambda *args: primitive.bind(*args, **params),
                         (primitive, tuple(params.items())))
    f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes)
    with core.eval_context(), jax.disable_jit(False):
      outvals = f_mapped(*vals)
    if primitive.multiple_results:
      return [MapTracer(self, val, out_shard_axes) for val in outvals]
    return MapTracer(self, outvals, out_shard_axes)

  def process_call(self, call_primitive, fun, tracers, params):
    raise NotImplementedError

  def process_map(self, map_primitive, fun, tracers, params):
    if params['devices'] is not None:
      raise ValueError("Nested pmap with explicit devices argument.")
    if not config.disable_jit.value:
      bind = HashableFunction(
          lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
          (map_primitive, fun))
      fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
      return self.process_primitive(fake_primitive, tracers, params)
    axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
        params["in_axes"], params["out_axes_thunk"], params["axis_size"])
    vals, shard_axes = unzip2((t.val, t.shard_axes) for t in tracers)
    shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
                  if ax is not None else s
                  for v, ax, s in zip(vals, in_axes, shard_axes)]
    # TODO(mattjj): use _emap_subtrace here?
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
      t = self.main.with_cur_sublevel()
      in_tracers = map(partial(MapTracer, t), vals, shard_axes)
      ans = fun.call_wrapped(*in_tracers)
      out_tracers = map(t.full_raise, ans)
      out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
      del t, in_tracers, ans, out_tracers
    out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
                          for v, s, dst in zip(out, outaxes, out_axes_thunk()))
    return map(partial(MapTracer, self), out, outaxes)

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    if symbolic_zeros:
      msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. "
             "Please open an issue at https://github.com/google/jax/issues !")
      raise NotImplementedError(msg)
    del prim, jvp, symbolic_zeros  # always base main, can drop jvp
    in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
    fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
    with core.new_sublevel():
      out_vals = fun.call_wrapped(*in_vals)
    return map(partial(MapTracer, self), out_vals, out_axes())

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                              out_trees, symbolic_zeros):
    if symbolic_zeros:
      msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. "
             "Please open an issue at https://github.com/google/jax/issues !")
      raise NotImplementedError(msg)
    del primitive, fwd, bwd, out_trees, symbolic_zeros  # always base main, drop vjp
    in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
    fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
    with core.new_sublevel():
      out_vals = fun.call_wrapped(*in_vals)
    return map(partial(MapTracer, self), out_vals, out_axes())

  def process_axis_index(self, frame):
    bind = HashableFunction(
        lambda _: jax.lax.axis_index(frame.name),
        (jax.lax.axis_index, frame.name))
    fake_primitive = FakePrimitive(multiple_results=False, bind=bind)
    with core.eval_context():
      range = jax.lax.iota(np.int32, frame.size)
    dummy_tracer = MapTracer(self, range, {frame.name: 0})
    return self.process_primitive(fake_primitive, (dummy_tracer,), {})


@lu.transformation_with_aux
def _emap_subtrace(main, in_axes, *in_vals):
  t = main.with_cur_sublevel()
  in_tracers = map(partial(MapTracer, t), in_vals, in_axes)
  ans = yield in_tracers, {}
  out_tracers = map(t.full_raise, ans)
  out_vals, out_axes = unzip2((t.val, t.shard_axes) for t in out_tracers)
  del t, in_tracers, ans, out_tracers
  yield out_vals, out_axes


def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
                   annotation: int | None) -> int | None:
  if annotation is None: return None
  mapped_axes_ = set(mapped_axes)
  return [i for i in range(ndim) if i not in mapped_axes_][annotation]
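# For example, _annot_to_flat(ndim=3, mapped_axes=[0], annotation=1) == 2: the
# annotation indexes into the unmapped positions [1, 2] of the full shape.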


def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
                 shard_axis_src: dict[core.AxisName, int],
                 dst_annotation: int | None
                 ) -> tuple[Any, dict[core.AxisName, int]]:
  shard_axis_out = dict(shard_axis_src)
  src = shard_axis_out.pop(axis_name, None)
  dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
                       dst_annotation)
  with core.eval_context():
    if src == dst:
      outval = val
    elif type(src) == type(dst) == int:
      outval = batching.moveaxis(val, src, dst)
      shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst)
    elif src is None and dst is not None:
      outval = batching.broadcast(val, axis_size, dst)
      shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
    else:
      raise NotImplementedError
  return outval, shard_axis_out


def _moveaxis(ndim: int, shard_axes: dict[core.AxisName, int],
              src: int, dst: int) -> dict[core.AxisName, int]:
  lst: list[core.AxisName | None] = [None] * ndim
  for k, v in shard_axes.items():
    lst[v] = k
  name = lst.pop(src)
  lst.insert(dst - (src < dst), name)
  return {name: i for i, name in enumerate(lst) if name is not None}


class MapTracer(core.Tracer):
  __slots__ = ["val", "shard_axes"]

  def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):
    self._trace = trace
    self.val = val
    self.shard_axes = shard_axes
    assert all(val < self.val.ndim for val in self.shard_axes.values())

  @property
  def aval(self):
    aval = xla.abstractify(self.val)
    shard_axes = dict(self.shard_axes)
    for axis_idx in sorted(shard_axes.values())[::-1]:
      aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
    return aval

  def full_lower(self):
    return self

  def __str__(self):
    named_axes = [f"{k}={v}" for k, v in self.shard_axes.items()]
    return f"{self.val}{{{','.join(named_axes)}}}"


@lu.cache
def parallel_callable(fun: lu.WrappedFun,
                      backend_name: str | None,
                      axis_name: core.AxisName,
                      axis_size: int,
                      global_axis_size: int,
                      devices: Sequence[Any] | None,
                      name: str,
                      in_axes: Sequence[int | None],
                      out_axes_thunk: Callable[[], Sequence[int | None]],
                      donated_invars: Sequence[bool],
                      is_explicit_global_axis_size: bool,
                      *avals):
  pmap_computation = lower_parallel_callable(
      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, avals,
      lowering_parameters=mlir.LoweringParameters())
  pmap_executable = pmap_computation.compile()
  return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])


@dataclasses.dataclass(frozen=True)
class ParallelCallableInfo:
  name: str
  backend: xc.Client
  axis_name: core.AxisName
  axis_size: int
  global_axis_size: int
  devices: Sequence[xc.Device] | None
  in_axes: Iterable[int | None]
  out_axes_thunk: Callable[[], Sequence[int | None]]
  avals: Sequence[core.AbstractValue]

  @cached_property
  def local_devices(self):
    if self.devices:
      out = [d for d in self.devices
             if d.process_index == xb.process_index(self.backend)]
      assert len(out) > 0
    else:
      out = None  # type: ignore
    return out

  @cached_property
  def out_axes(self):
    return self.out_axes_thunk()


class ShardInfo(NamedTuple):
  sharded_avals: Sequence[core.AbstractValue]
  out_sharded_avals: Sequence[core.ShapedArray]
  global_sharded_avals: Sequence[core.AbstractValue]
  num_local_shards: int
  num_global_shards: int


class ReplicaInfo(NamedTuple):
  jaxpr_replicas: int
  num_local_replicas: int
  num_global_replicas: int


def find_replicas(
    jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int
) -> ReplicaInfo:
  # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
  jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
  num_local_replicas = axis_size * jaxpr_replicas
  num_global_replicas = global_axis_size * jaxpr_replicas
  return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)
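# For example, a jaxpr with jaxpr_replicas == 2 pmapped over axis_size == 4
# local devices and global_axis_size == 8 gives ReplicaInfo(2, 8, 16).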


@lu.transformation
def _change_argument_ranks(in_axes, out_axes_thunk, *args):
  args = tuple(
      arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,))
      for in_axis, arg in zip(in_axes, args)
  )
  results = yield (args, {})
  out_axes = out_axes_thunk()
  yield tuple(
      x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,))
      for x, axis in zip(results, out_axes)
  )
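# Note: _change_argument_ranks is applied when jax_pmap_no_rank_reduction is
# enabled. It squeezes away each argument's size-1 mapped axis before calling
# the traced function and re-inserts the corresponding output axes afterwards.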


def stage_parallel_callable(
    pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> tuple[core.Jaxpr, list[Any], ReplicaInfo, ShardInfo]:
  sharded_avals = tuple(
      _shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
      for axis, aval in safe_zip(pci.in_axes, pci.avals))

  orig_fun = fun
  if config.pmap_no_rank_reduction.value:
    fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk)
  else:
    fun = orig_fun
  with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):  # type: ignore
    with dispatch.log_elapsed_time(
        "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
        fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
          fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
  jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  assert len(out_sharded_avals) == len(pci.out_axes), (
      len(out_sharded_avals), len(pci.out_axes))

  replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
  num_local_shards = replicas.num_local_replicas
  num_global_shards = replicas.num_global_replicas

  shards = ShardInfo(
      sharded_avals, out_sharded_avals, sharded_avals,
      num_local_shards, num_global_shards)

  return jaxpr, consts, replicas, shards

@profiler.annotate_function
|
|
|
|
|
def lower_parallel_callable(
|
|
|
|
|
fun: lu.WrappedFun,
|
2023-07-21 14:20:39 -04:00
|
|
|
|
backend_name: str | None,
|
2023-02-06 14:28:36 -08:00
|
|
|
|
axis_name: core.AxisName,
|
|
|
|
|
axis_size: int,
|
|
|
|
|
global_axis_size: int,
|
2023-07-21 14:20:39 -04:00
|
|
|
|
devices: Sequence[xc.Device] | None,
|
2023-02-06 14:28:36 -08:00
|
|
|
|
name: str,
|
2023-07-21 14:20:39 -04:00
|
|
|
|
in_axes: Iterable[int | None],
|
|
|
|
|
out_axes_thunk: Callable[[], Sequence[int | None]],
|
2023-02-06 14:28:36 -08:00
|
|
|
|
donated_invars: Sequence[bool],
|
|
|
|
|
is_explicit_global_axis_size: bool,
|
2023-02-28 11:30:23 +01:00
|
|
|
|
avals: Sequence[core.AbstractValue],
|
|
|
|
|
*,
|
|
|
|
|
lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
|
2023-02-06 14:28:36 -08:00
|
|
|
|
# Determine global_axis_size for use in AxisEnv.
|
|
|
|
|
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
|
|
|
|
|
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
|
|
|
|
|
# raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
|
|
|
|
|
if (xb.process_count() == 1 and is_explicit_global_axis_size
|
|
|
|
|
and global_axis_size != axis_size):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Specified axis_size {global_axis_size} doesn't match received "
|
|
|
|
|
f"axis_size {axis_size}.")
|
|
|
|
|
|
|
|
|
|
if devices is not None and backend_name is None:
|
|
|
|
|
backend = xb.get_device_backend(devices[0])
|
|
|
|
|
else:
|
|
|
|
|
backend = xb.get_backend(backend_name)
|
|
|
|
|
|
|
|
|
|
no_nested_sharding = False
|
|
|
|
|
must_run_on_all_devices = False
|
|
|
|
|
if not is_explicit_global_axis_size:
|
|
|
|
|
if xb.process_count(backend) > 1:
|
|
|
|
|
if devices:
|
|
|
|
|
# This allows each host in a multi-host pmap to run on a different number
|
2023-04-12 08:49:07 -07:00
|
|
|
|
# of devices, but precludes nested sharding (i.e. inner pmaps).
|
2023-02-06 14:28:36 -08:00
|
|
|
|
no_nested_sharding = True
|
|
|
|
|
else:
|
|
|
|
|
# This assumes all hosts run on the same number of devices. We make sure
|
|
|
|
|
# this assumption is true by requiring that the pmap is run on all devices
|
|
|
|
|
# (and making the further assumption that each host has the same number of
|
|
|
|
|
# devices). Nested sharding is ok in this case.
|
|
|
|
|
must_run_on_all_devices = True
|
|
|
|
|
|
|
|
|
|
pci = ParallelCallableInfo(
|
|
|
|
|
name, backend, axis_name, axis_size, global_axis_size, devices,
|
|
|
|
|
in_axes, out_axes_thunk, avals)
|
2023-04-12 08:49:07 -07:00
|
|
|
|
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
if logger.isEnabledFor(logging.DEBUG):
|
|
|
|
|
logger.debug("sharded_avals: %s", shards.sharded_avals)
|
|
|
|
|
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
|
|
|
|
|
logger.debug("num_replicas: %d num_local_replicas: %d",
|
|
|
|
|
replicas.num_global_replicas, replicas.num_local_replicas)
|
|
|
|
|
logger.debug("devices: %s", devices)
|
|
|
|
|
logger.debug("local_devices: %s", pci.local_devices)
|
|
|
|
|
|
|
|
|
|
if (xb.process_count(backend) > 1 and must_run_on_all_devices and
|
|
|
|
|
shards.num_local_shards != xb.local_device_count(backend)):
|
|
|
|
|
if shards.num_local_shards == axis_size:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"On multi-host platforms, the input to pmapped functions must have "
|
|
|
|
|
f"leading axis size equal to the number of local devices if no "
|
|
|
|
|
f"`devices` argument is specified. Got {axis_size=}, "
|
|
|
|
|
f"num_local_devices={xb.local_device_count(backend)}")
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"On multi-host platforms, pmapped functions must run across all "
|
|
|
|
|
f"devices, i.e. num_replicas * num_partitions should equal the "
|
|
|
|
|
f"number of local devices. Got "
|
2023-04-12 08:49:07 -07:00
|
|
|
|
f"num_replicas={replicas.num_local_replicas}, and "
|
2023-02-06 14:28:36 -08:00
|
|
|
|
f"num_local_devices={xb.local_device_count(backend)}")
|
|
|
|
|
|
2023-04-12 08:49:07 -07:00
|
|
|
|
if no_nested_sharding and replicas.jaxpr_replicas > 1:
|
2023-02-06 14:28:36 -08:00
|
|
|
|
raise ValueError(
|
|
|
|
|
f"On multi-host platforms, pmapped functions that both have `devices` "
|
2023-04-12 08:49:07 -07:00
|
|
|
|
f"specified and contain an inner_pmap must specify an "
|
2023-02-06 14:28:36 -08:00
|
|
|
|
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
|
2023-04-12 08:49:07 -07:00
|
|
|
|
f"{replicas.jaxpr_replicas}")
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
2023-10-09 07:28:18 -07:00
|
|
|
|
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
|
2023-04-17 07:52:56 -07:00
|
|
|
|
if logger.isEnabledFor(log_priority):
|
|
|
|
|
logger.log(log_priority,
|
|
|
|
|
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
|
|
|
|
|
fun.__name__, id(fun),
|
|
|
|
|
shards.num_global_shards, avals, replicas.num_global_replicas)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
2023-04-10 10:15:08 -07:00
|
|
|
|
axis_env = sharding_impls.AxisEnv(
|
2023-02-06 14:28:36 -08:00
|
|
|
|
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
|
2023-02-27 11:37:10 -08:00
|
|
|
|
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
|
2024-03-04 05:41:29 -08:00
|
|
|
|
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
|
2023-02-06 14:28:36 -08:00
|
|
|
|
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
|
|
|
|
replicated_args = [axis is None for axis in in_axes]
|
|
|
|
|
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
|
|
|
|
|
backend.platform)
|
|
|
|
|
module_name = f"pmap_{fun.__name__}"
|
|
|
|
|
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
2023-02-01 17:50:00 -08:00
|
|
|
|
ordered_effects = list(
|
|
|
|
|
effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
|
|
|
|
if ordered_effects:
|
2023-02-06 14:28:36 -08:00
|
|
|
|
raise ValueError("Ordered effects not supported in `pmap`.")
|
2023-02-01 17:50:00 -08:00
|
|
|
|
unordered_effects = list(
|
|
|
|
|
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
2023-05-15 08:07:31 -07:00
|
|
|
|
with dispatch.log_elapsed_time(
|
2023-05-15 09:15:22 -07:00
|
|
|
|
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
|
|
|
|
|
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
|
2023-05-15 08:07:31 -07:00
|
|
|
|
lowering_result = mlir.lower_jaxpr_to_module(
|
|
|
|
|
module_name,
|
|
|
|
|
closed_jaxpr,
|
2023-09-28 12:44:14 +02:00
|
|
|
|
ordered_effects=ordered_effects,
|
|
|
|
|
backend_or_name=backend,
|
2023-10-25 10:39:47 -07:00
|
|
|
|
platforms=lowering_parameters.platforms or (backend.platform,),
|
2023-09-28 12:44:14 +02:00
|
|
|
|
axis_context=sharding_impls.ReplicaAxisContext(axis_env),
|
|
|
|
|
name_stack=name_stack,
|
|
|
|
|
donated_args=donated_invars,
|
2023-05-15 08:07:31 -07:00
|
|
|
|
replicated_args=replicated_args,
|
|
|
|
|
arg_shardings=None,
|
|
|
|
|
result_shardings=None,
|
|
|
|
|
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
|
|
|
|
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
2023-09-28 12:44:14 +02:00
|
|
|
|
num_replicas=replicas.num_global_replicas,
|
|
|
|
|
lowering_parameters=lowering_parameters)
|
2023-04-21 14:37:52 -07:00
|
|
|
|
return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
|
2023-02-06 14:28:36 -08:00
|
|
|
|
shards=shards, tuple_args=tuple_args,
|
|
|
|
|
unordered_effects=unordered_effects,
|
|
|
|
|
ordered_effects=ordered_effects,
|
2023-04-21 14:37:52 -07:00
|
|
|
|
keepalive=lowering_result.keepalive,
|
|
|
|
|
host_callbacks=lowering_result.host_callbacks,
|
2023-04-19 15:08:21 -07:00
|
|
|
|
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
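# Rough usage sketch (assumed call pattern, for orientation only):
#   computation = lower_parallel_callable(
#       fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
#       in_axes, out_axes_thunk, donated_invars, is_explicit_global_axis_size,
#       avals, lowering_parameters=mlir.LoweringParameters())
#   executable = computation.compile()
# i.e. lowering yields a PmapComputation whose .compile() produces a PmapExecutable.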
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pmap_unmap_shaped_array(
|
|
|
|
|
size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray
|
|
|
|
|
) -> ShapedArray:
|
|
|
|
|
named_shape = dict(aval.named_shape)
|
|
|
|
|
named_shape.pop(axis_name, None) # TODO: make this mandatory
|
|
|
|
|
if axis is None: return aval.update(named_shape=named_shape)
|
|
|
|
|
elif type(axis) is int:
|
|
|
|
|
return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype,
|
|
|
|
|
named_shape=named_shape, weak_type=aval.weak_type)
|
|
|
|
|
else: raise TypeError(axis)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AvalMapHandlerPair = tuple[Any, Callable]
|
|
|
|
|
_pmap_aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
|
|
|
|
|
ShapedArray: (Any, _pmap_unmap_shaped_array),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
|
|
|
|
|
aval: core.AbstractValue) -> core.AbstractValue:
|
|
|
|
|
if not config.pmap_no_rank_reduction.value:
|
|
|
|
|
return core.unmapped_aval(size, axis_name, axis, aval)
|
|
|
|
|
|
|
|
|
|
_, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
|
|
|
|
|
if handler is not None:
|
|
|
|
|
return handler(size, axis_name, axis, aval)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
|
|
|
|
|
|
|
|
|
|
|
2023-02-06 14:28:36 -08:00
|
|
|
|
class PmapComputation(stages.XlaLowering):
|
|
|
|
|
_hlo: ir.Module
|
2023-07-21 14:20:39 -04:00
|
|
|
|
_executable: PmapExecutable | None
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
def __init__(self, hlo: ir.Module, **compile_args):
|
|
|
|
|
self._executable = None
|
|
|
|
|
self._hlo = hlo
|
|
|
|
|
self.compile_args = compile_args
|
|
|
|
|
|
|
|
|
|
# -- stages.XlaLowering overrides
|
|
|
|
|
|
|
|
|
|
def stablehlo(self) -> ir.Module:
|
|
|
|
|
return self._hlo
|
|
|
|
|
|
|
|
|
|
@profiler.annotate_function
|
2023-03-30 17:13:46 -07:00
|
|
|
|
def compile(self, compiler_options=None) -> PmapExecutable:
|
|
|
|
|
if self._executable is None or compiler_options is not None:
|
|
|
|
|
executable = UnloadedPmapExecutable.from_hlo(
|
|
|
|
|
self._hlo, **self.compile_args,
|
|
|
|
|
compiler_options=compiler_options)
|
|
|
|
|
if compiler_options is None:
|
|
|
|
|
self._executable = executable
|
|
|
|
|
return executable
|
2023-02-06 14:28:36 -08:00
|
|
|
|
return self._executable
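# Caching note (restating the logic above, with hypothetical names): calling
#   exe1 = pmap_computation.compile()
#   exe2 = pmap_computation.compile()
# returns the same cached PmapExecutable, whereas passing compiler_options forces
# a fresh compile and deliberately leaves the cached executable untouched.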
|
|
|
|
|
|
2023-04-12 08:49:07 -07:00
|
|
|
|
def _cast_to_shaped_array(aval: core.AbstractValue) -> ShapedArray:
|
|
|
|
|
assert isinstance(aval, ShapedArray), aval
|
|
|
|
|
return cast(ShapedArray, aval)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
|
|
|
|
|
class UnloadedPmapExecutable:
|
|
|
|
|
compiled: Any
|
|
|
|
|
backend: xb.XlaBackend
|
2023-02-14 23:00:40 -08:00
|
|
|
|
local_input_avals: Sequence[core.AbstractValue]
|
2023-03-13 08:49:39 -07:00
|
|
|
|
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
local_output_avals: Sequence[ShapedArray]
|
2023-03-13 08:49:39 -07:00
|
|
|
|
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
|
2023-06-23 15:11:37 -07:00
|
|
|
|
unordered_effects: list[core.Effect]
|
|
|
|
|
ordered_effects: list[core.Effect]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
keepalive: Sequence[Any]
|
|
|
|
|
host_callbacks: Sequence[Any]
|
2023-04-19 15:08:21 -07:00
|
|
|
|
jaxpr_debug_info: core.JaxprDebugInfo
|
|
|
|
|
|
|
|
|
|
def build_execute_fun(self):
|
|
|
|
|
input_indices = []
|
|
|
|
|
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
|
|
|
|
|
assert isinstance(spec, sharding_impls.PmapSharding), spec
|
|
|
|
|
assert isinstance(aval, core.ShapedArray), aval
|
|
|
|
|
input_indices.append(
|
|
|
|
|
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
|
|
|
|
|
if spec.sharding_spec is not None else None)
|
|
|
|
|
handle_outs = local_avals_to_results_handler(self.local_output_avals,
|
|
|
|
|
self.output_shardings)
|
2024-01-05 14:16:32 -08:00
|
|
|
|
handle_args = InputsHandler(self.input_shardings,
|
|
|
|
|
self.compiled.local_devices(), input_indices)
|
2023-04-19 15:08:21 -07:00
|
|
|
|
execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
|
|
|
|
|
self.backend, handle_args, handle_outs,
|
|
|
|
|
self.unordered_effects,
|
|
|
|
|
self.ordered_effects, self.keepalive,
|
|
|
|
|
bool(self.host_callbacks),
|
2024-03-01 09:27:57 -08:00
|
|
|
|
set(range(len(input_indices))), None)
|
2023-04-19 15:08:21 -07:00
|
|
|
|
return execute_fun
|
|
|
|
|
|
|
|
|
|
def load(self) -> PmapExecutable:
|
|
|
|
|
fingerprint = getattr(self.compiled, "fingerprint", None)
|
|
|
|
|
|
|
|
|
|
return PmapExecutable(
|
|
|
|
|
self.compiled, self.build_execute_fun, fingerprint,
|
|
|
|
|
self.local_input_avals, self.jaxpr_debug_info, self)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
@staticmethod
|
2023-04-21 14:37:52 -07:00
|
|
|
|
def from_hlo(hlo: ir.Module,
|
2023-02-06 14:28:36 -08:00
|
|
|
|
pci: ParallelCallableInfo,
|
|
|
|
|
replicas: ReplicaInfo,
|
|
|
|
|
shards: ShardInfo,
|
|
|
|
|
tuple_args: bool,
|
2023-06-23 15:11:37 -07:00
|
|
|
|
unordered_effects: list[core.Effect],
|
|
|
|
|
ordered_effects: list[core.Effect],
|
|
|
|
|
host_callbacks: list[Any],
|
2023-03-30 17:13:46 -07:00
|
|
|
|
keepalive: Any,
|
2023-04-19 15:08:21 -07:00
|
|
|
|
jaxpr_debug_info: core.JaxprDebugInfo,
|
2023-03-30 17:13:46 -07:00
|
|
|
|
compiler_options=None):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
devices = pci.devices
|
|
|
|
|
if devices is None:
|
|
|
|
|
if shards.num_global_shards > xb.device_count(pci.backend):
|
|
|
|
|
msg = ("compiling computation that requires {} logical devices, but only {} XLA "
|
2023-04-12 08:49:07 -07:00
|
|
|
|
"devices are available (num_replicas={})")
|
2023-02-06 14:28:36 -08:00
|
|
|
|
raise ValueError(msg.format(shards.num_global_shards,
|
|
|
|
|
xb.device_count(pci.backend),
|
2023-04-12 08:49:07 -07:00
|
|
|
|
replicas.num_global_replicas))
|
2023-02-06 14:28:36 -08:00
|
|
|
|
# On a single host, we simply grab the first N devices from jax.devices().
|
|
|
|
|
# In the single host case, we want the default device order of pmap to
|
|
|
|
|
# match jax.devices().
|
|
|
|
|
# On multiple hosts, we create a default device assignment that ensures
|
|
|
|
|
# each host is responsible for a contiguous set of replicas.
|
|
|
|
|
if shards.num_global_shards > shards.num_local_shards:
|
|
|
|
|
# TODO(skye): use a locality-aware assignment that satisfies the above
|
|
|
|
|
# constraint.
|
|
|
|
|
devices = [d for process_index in range(xb.process_count(pci.backend))
|
|
|
|
|
for d in xb.local_devices(process_index, pci.backend)]
|
|
|
|
|
else:
|
|
|
|
|
devices = xb.local_devices(backend=pci.backend)[:shards.num_local_shards]
|
|
|
|
|
else:
|
|
|
|
|
if shards.num_local_shards != len(pci.local_devices):
|
|
|
|
|
local_devices_str = ", ".join(map(str, pci.local_devices))
|
|
|
|
|
if shards.num_local_shards == pci.axis_size:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Leading axis size of input to pmapped function must equal the "
|
|
|
|
|
f"number of local devices passed to pmap. Got axis_size="
|
|
|
|
|
f"{pci.axis_size}, num_local_devices={len(pci.local_devices)}.\n"
|
|
|
|
|
f"(Local devices available to pmap: {local_devices_str})")
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"pmapped function requires {shards.num_local_shards} local "
|
|
|
|
|
f"devices to run due to nested pmapped or other parallel "
|
|
|
|
|
f"functions, but only {len(pci.local_devices)} are available.\n"
|
|
|
|
|
f"(outer axis size: {pci.axis_size}, local devices available to "
|
|
|
|
|
f"pmap: {local_devices_str})")
|
|
|
|
|
if shards.num_global_shards != len(devices):
|
|
|
|
|
raise ValueError("compiling computation that creates %s shards, "
|
|
|
|
|
"but %s devices were specified" %
|
|
|
|
|
(shards.num_global_shards, len(devices)))
|
|
|
|
|
|
|
|
|
|
# 'devices' may be 1D or 2D at this point (e.g.
|
|
|
|
|
# get_default_device_assignment() returns 2D assignment, caller may have
|
|
|
|
|
# provided 1D list of devices).
|
|
|
|
|
# Convert to 2D in case it's 1D and we have > 1 partitions.
|
2023-04-12 08:49:07 -07:00
|
|
|
|
num_partitions = 1
|
2023-02-06 14:28:36 -08:00
|
|
|
|
device_assignment: np.ndarray = np.array(devices).reshape(
|
2023-04-12 08:49:07 -07:00
|
|
|
|
(replicas.num_global_replicas, num_partitions))
|
2023-08-15 06:38:56 -07:00
|
|
|
|
compile_options = compiler.get_compile_options(
|
2023-02-06 14:28:36 -08:00
|
|
|
|
num_replicas=replicas.num_global_replicas,
|
2023-04-12 08:49:07 -07:00
|
|
|
|
num_partitions=num_partitions,
|
2023-02-06 14:28:36 -08:00
|
|
|
|
device_assignment=device_assignment,
|
2023-04-12 08:49:07 -07:00
|
|
|
|
use_spmd_partitioning=False,
|
2023-03-30 17:13:46 -07:00
|
|
|
|
env_options_overrides=compiler_options,
|
2023-11-20 15:51:27 -08:00
|
|
|
|
detailed_logging=compiler.use_detailed_logging(hlo),
|
|
|
|
|
backend=pci.backend,
|
2023-02-06 14:28:36 -08:00
|
|
|
|
)
|
|
|
|
|
compile_options.parameter_is_tupled_arguments = tuple_args
|
|
|
|
|
|
|
|
|
|
process_index = xb.process_index(pci.backend)
|
|
|
|
|
local_device_assignment = np.array([
|
|
|
|
|
d for d in device_assignment.flat if d.process_index == process_index
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
input_sharding_specs = [
|
2023-04-06 09:48:14 -07:00
|
|
|
|
sharding_specs.pmap_sharding_spec(
|
|
|
|
|
replicas.num_local_replicas, pci.axis_size,
|
|
|
|
|
cast(ShapedArray, aval).shape, in_axis)
|
2023-04-12 08:49:07 -07:00
|
|
|
|
for aval, in_axis in safe_zip(shards.sharded_avals, pci.in_axes)]
|
|
|
|
|
in_shardings = _get_pmap_sharding(local_device_assignment,
|
|
|
|
|
input_sharding_specs)
|
|
|
|
|
|
2023-02-06 14:28:36 -08:00
|
|
|
|
local_unmapped_avals = [
|
2023-04-12 08:49:07 -07:00
|
|
|
|
_cast_to_shaped_array(
|
|
|
|
|
_pmap_unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
|
2023-02-06 14:28:36 -08:00
|
|
|
|
if out_axis is not None else aval
|
2023-04-12 08:49:07 -07:00
|
|
|
|
for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
out_specs = [
|
2023-04-06 09:48:14 -07:00
|
|
|
|
sharding_specs.pmap_sharding_spec(
|
2023-04-12 08:49:07 -07:00
|
|
|
|
replicas.num_local_replicas, pci.axis_size, aval.shape, out_axis)
|
|
|
|
|
for aval, out_axis in safe_zip(
|
|
|
|
|
shards.out_sharded_avals, pci.out_axes)]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
|
|
|
|
|
|
|
|
|
|
if hasattr(pci.backend, "compile_replicated"):
|
|
|
|
|
input_indices = [
|
2023-04-06 09:48:14 -07:00
|
|
|
|
sharding_specs.spec_to_indices(aval.shape, spec)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
if spec is not None else None
|
|
|
|
|
for aval, spec in safe_zip(pci.avals, input_sharding_specs)
|
|
|
|
|
]
|
|
|
|
|
handle_outs = local_avals_to_results_handler(local_unmapped_avals,
|
|
|
|
|
out_shardings)
|
|
|
|
|
return _compile_replicated_pmap_executable_from_hlo(
|
2023-04-21 14:37:52 -07:00
|
|
|
|
hlo, pci, input_indices, in_shardings, handle_outs,
|
2023-02-06 14:28:36 -08:00
|
|
|
|
compile_options, host_callbacks, bool(unordered_effects),
|
2023-04-19 15:08:21 -07:00
|
|
|
|
ordered_effects, jaxpr_debug_info)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
with dispatch.log_elapsed_time(
|
2023-05-15 09:15:22 -07:00
|
|
|
|
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
|
|
|
|
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
|
2023-08-15 06:38:56 -07:00
|
|
|
|
compiled = compiler.compile_or_get_cached(
|
2023-04-21 14:37:52 -07:00
|
|
|
|
pci.backend, hlo, device_assignment, compile_options,
|
2023-04-20 06:16:12 -07:00
|
|
|
|
host_callbacks)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
return UnloadedPmapExecutable(
|
|
|
|
|
compiled=compiled,
|
|
|
|
|
backend=pci.backend,
|
|
|
|
|
local_input_avals=pci.avals,
|
|
|
|
|
input_shardings=in_shardings,
|
|
|
|
|
local_output_avals=local_unmapped_avals,
|
|
|
|
|
output_shardings=out_shardings,
|
|
|
|
|
unordered_effects=unordered_effects,
|
|
|
|
|
ordered_effects=ordered_effects,
|
|
|
|
|
keepalive=keepalive,
|
|
|
|
|
host_callbacks=host_callbacks,
|
2023-04-19 15:08:21 -07:00
|
|
|
|
jaxpr_debug_info=jaxpr_debug_info).load()
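# Flow summary (descriptive only): from_hlo picks or validates the device
# assignment, derives per-input and per-output PmapShardings from the sharding
# specs, compiles the HLO module (or reuses a cached compilation via
# compile_or_get_cached), and returns the loaded executable through
# UnloadedPmapExecutable.load().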
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
|
2023-03-13 14:08:48 -07:00
|
|
|
|
def _compile_replicated_pmap_executable_from_hlo(
|
2023-04-21 14:37:52 -07:00
|
|
|
|
hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
|
2023-04-19 15:08:21 -07:00
|
|
|
|
compile_options, host_callbacks, has_unordered_effects, ordered_effects,
|
|
|
|
|
jaxpr_debug_info):
|
2023-03-13 14:08:48 -07:00
|
|
|
|
# Use the standard out_handler.
|
|
|
|
|
execute_fun = pci.backend.compile_replicated(
|
2023-04-21 14:37:52 -07:00
|
|
|
|
is_trivial=False, name=pci.name, computation=hlo,
|
2023-03-13 14:08:48 -07:00
|
|
|
|
compile_options=compile_options, host_callbacks=host_callbacks,
|
|
|
|
|
has_unordered_effects=has_unordered_effects,
|
|
|
|
|
ordered_effects=ordered_effects, in_avals=pci.avals,
|
|
|
|
|
in_indices=input_indices, in_shardings=in_shardings,
|
|
|
|
|
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
|
|
|
|
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
2023-04-19 15:08:21 -07:00
|
|
|
|
return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
|
|
|
|
|
jaxpr_debug_info, None)
|
2023-03-13 14:08:48 -07:00
|
|
|
|
|
|
|
|
|
|
2023-02-06 14:28:36 -08:00
|
|
|
|
class PmapExecutable(stages.XlaExecutable):
|
2023-03-22 17:22:39 -07:00
|
|
|
|
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
|
2023-04-19 15:08:21 -07:00
|
|
|
|
"fingerprint", "in_avals", "_jaxpr_debug_info",
|
|
|
|
|
"_unloaded_executable"]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
2023-03-22 17:22:39 -07:00
|
|
|
|
def __init__(self, xla_executable, build_unsafe_call, fingerprint,
|
2023-04-19 15:08:21 -07:00
|
|
|
|
in_avals, jaxpr_debug_info, unloaded_executable):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
self.xla_executable = xla_executable
|
2023-03-22 17:22:39 -07:00
|
|
|
|
self._unsafe_call = None
|
|
|
|
|
self.build_unsafe_call = build_unsafe_call
|
2023-02-06 14:28:36 -08:00
|
|
|
|
self.fingerprint = fingerprint
|
|
|
|
|
self.in_avals = in_avals
|
2023-04-19 15:08:21 -07:00
|
|
|
|
self._jaxpr_debug_info = jaxpr_debug_info
|
2023-03-22 17:22:39 -07:00
|
|
|
|
self._unloaded_executable = unloaded_executable
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def unsafe_call(self) -> Callable[..., Any]:
|
|
|
|
|
if self._unsafe_call is None:
|
|
|
|
|
self._unsafe_call = self.build_unsafe_call()
|
|
|
|
|
return self._unsafe_call
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
# -- stages.XlaExecutable overrides
|
|
|
|
|
|
|
|
|
|
def xla_extension_executable(self):
|
|
|
|
|
return self.xla_executable
|
|
|
|
|
|
|
|
|
|
@profiler.annotate_function
|
|
|
|
|
def call(self, *args):
|
|
|
|
|
# TODO(frostig): do we need to check sharding and sharded avals?
|
|
|
|
|
arg_avals = map(xla.abstractify, args)
|
2023-04-19 15:08:21 -07:00
|
|
|
|
check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
|
2023-03-22 17:22:39 -07:00
|
|
|
|
return self.unsafe_call(*args) # pylint: disable=not-callable
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_pmap_sharding(devices, specs):
|
2023-03-13 08:49:39 -07:00
|
|
|
|
return [sharding_impls.PmapSharding(devices, spec) for spec in specs]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InputsHandler:
|
|
|
|
|
__slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
|
|
|
|
|
|
2024-01-05 14:16:32 -08:00
|
|
|
|
def __init__(self, in_shardings, local_devices=None, input_indices=None):
|
|
|
|
|
self.handler = partial(shard_args, in_shardings)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
self.local_devices = local_devices
|
|
|
|
|
self.in_shardings = in_shardings
|
|
|
|
|
self.input_indices = input_indices
|
|
|
|
|
|
|
|
|
|
def __call__(self, input_buffers):
|
|
|
|
|
return self.handler(input_buffers)
|
|
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
|
return ("InputsHandler(\n"
|
|
|
|
|
f"local_devices={self.local_devices},\n"
|
|
|
|
|
f"in_shardings={self.in_shardings},\n"
|
|
|
|
|
f"input_indices={self.input_indices})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResultsHandler:
|
2024-03-18 11:41:17 -07:00
|
|
|
|
# `out_avals` is the `Array` global avals when using pjit or xmap. It is the
|
|
|
|
|
# local one when using `pmap`.
|
2023-02-06 14:28:36 -08:00
|
|
|
|
__slots__ = ("handlers", "out_shardings", "out_avals")
|
|
|
|
|
|
|
|
|
|
def __init__(self, handlers, out_shardings, out_avals):
|
|
|
|
|
self.handlers = handlers
|
|
|
|
|
self.out_shardings = out_shardings
|
|
|
|
|
self.out_avals = out_avals
|
|
|
|
|
|
|
|
|
|
def __call__(self, out_bufs):
|
|
|
|
|
return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def local_avals_to_results_handler(
|
|
|
|
|
unmapped_local_out_avals: Sequence[ShapedArray],
|
2023-03-13 08:49:39 -07:00
|
|
|
|
local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
|
2023-02-06 14:28:36 -08:00
|
|
|
|
out_indices = [tuple(s.devices_indices_map(aval.shape).values())
|
|
|
|
|
for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
|
|
|
|
|
handlers = [
|
|
|
|
|
local_aval_to_result_handler(aval, s, idcs)
|
|
|
|
|
for aval, s, idcs in safe_zip(unmapped_local_out_avals, local_shardings, out_indices)
|
|
|
|
|
]
|
|
|
|
|
return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def global_avals_to_results_handler(
|
|
|
|
|
global_out_avals: Sequence[ShapedArray],
|
2023-03-13 08:49:39 -07:00
|
|
|
|
shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
2024-02-28 15:21:50 -08:00
|
|
|
|
committed: bool) -> ResultsHandler:
|
2023-03-15 17:08:21 -07:00
|
|
|
|
handlers = [
|
2024-02-28 15:21:50 -08:00
|
|
|
|
global_aval_to_result_handler(global_aval, s, committed)
|
|
|
|
|
for global_aval, s in safe_zip(global_out_avals, shardings)
|
2023-03-15 17:08:21 -07:00
|
|
|
|
]
|
|
|
|
|
return ResultsHandler(handlers, shardings, global_out_avals)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExecuteReplicated:
|
|
|
|
|
"""The logic to shard inputs, execute a replicated model, returning outputs."""
|
|
|
|
|
__slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
|
|
|
|
|
'has_unordered_effects', 'ordered_effects', 'keepalive',
|
|
|
|
|
'has_host_callbacks', '_local_devices', 'kept_var_idx',
|
2024-03-05 16:20:24 -08:00
|
|
|
|
'mut', '__weakref__']
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
|
|
|
|
|
out_handler: ResultsHandler,
|
2023-06-23 15:11:37 -07:00
|
|
|
|
unordered_effects: list[core.Effect],
|
|
|
|
|
ordered_effects: list[core.Effect], keepalive: Any,
|
|
|
|
|
has_host_callbacks: bool, kept_var_idx: set[int],
|
2024-03-05 16:20:24 -08:00
|
|
|
|
mut: MutationData | None):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
self.xla_executable = xla_executable
|
|
|
|
|
self.name = name
|
|
|
|
|
self.backend = backend
|
|
|
|
|
self.in_handler = in_handler
|
|
|
|
|
self.out_handler = out_handler
|
|
|
|
|
self.has_unordered_effects = bool(unordered_effects)
|
|
|
|
|
self.ordered_effects = ordered_effects
|
|
|
|
|
self._local_devices = self.xla_executable.local_devices()
|
|
|
|
|
self.keepalive = keepalive
|
|
|
|
|
self.has_host_callbacks = has_host_callbacks
|
|
|
|
|
self.kept_var_idx = kept_var_idx
|
2024-03-05 16:20:24 -08:00
|
|
|
|
self.mut = mut
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
2023-02-27 18:26:12 -08:00
|
|
|
|
def _add_tokens_to_inputs(self, input_bufs):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
if self.ordered_effects:
|
2023-09-18 02:49:53 -07:00
|
|
|
|
tokens = [
|
|
|
|
|
dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
|
|
|
|
|
for eff in self.ordered_effects]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
input_bufs = [*tokens, *input_bufs]
|
2023-02-27 18:26:12 -08:00
|
|
|
|
return input_bufs
|
|
|
|
|
|
|
|
|
|
def _handle_token_bufs(self, token_bufs, sharded_token):
|
2023-09-18 02:49:53 -07:00
|
|
|
|
# token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
|
|
|
|
|
# token buffer (as a singleton list).
|
|
|
|
|
# sharded_token: ShardedToken, containing the RuntimeTokens for each device
|
2023-02-06 14:28:36 -08:00
|
|
|
|
for i, device in enumerate(self._local_devices):
|
|
|
|
|
dispatch.runtime_tokens.set_output_runtime_token(
|
|
|
|
|
device, sharded_token.get_token(i))
|
|
|
|
|
for eff, token_buf in zip(self.ordered_effects, token_bufs):
|
2023-09-18 02:49:53 -07:00
|
|
|
|
dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
@profiler.annotate_function
|
|
|
|
|
def __call__(self, *args):
|
|
|
|
|
args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
|
2024-03-05 16:20:24 -08:00
|
|
|
|
if self.mut:
|
|
|
|
|
args = [*args, *self.mut.in_mut]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
input_bufs = self.in_handler(args)
|
2023-03-13 17:09:06 -07:00
|
|
|
|
if (self.ordered_effects or self.has_unordered_effects
|
|
|
|
|
or self.has_host_callbacks):
|
|
|
|
|
input_bufs = self._add_tokens_to_inputs(input_bufs)
|
|
|
|
|
results = self.xla_executable.execute_sharded(
|
|
|
|
|
input_bufs, with_tokens=True
|
|
|
|
|
)
|
2023-09-18 02:49:53 -07:00
|
|
|
|
result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
|
|
|
|
|
len(self.ordered_effects))
|
|
|
|
|
sharded_runtime_token = results.consume_token()
|
|
|
|
|
self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
else:
|
2023-03-13 17:09:06 -07:00
|
|
|
|
results = self.xla_executable.execute_sharded(input_bufs)
|
|
|
|
|
if dispatch.needs_check_special():
|
|
|
|
|
out_arrays = results.disassemble_into_single_device_arrays()
|
|
|
|
|
for arrays in out_arrays:
|
|
|
|
|
dispatch.check_special(self.name, arrays)
|
|
|
|
|
out = self.out_handler(out_arrays)
|
|
|
|
|
else:
|
|
|
|
|
out = results.consume_with_handlers(self.out_handler.handlers)
|
2024-03-05 16:20:24 -08:00
|
|
|
|
if self.mut is None:
|
|
|
|
|
return out
|
|
|
|
|
else:
|
|
|
|
|
out_ = []
|
2024-03-05 16:20:24 -08:00
|
|
|
|
for i, o in zip(self.mut.out_mut, out):
|
|
|
|
|
if i is not None:
|
|
|
|
|
args[i]._buf = o
|
|
|
|
|
else:
|
|
|
|
|
out_.append(o)
|
|
|
|
|
return out_
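# Illustrative example of the input pruning above (hypothetical values): with
# kept_var_idx == {0, 2}, a call f(a, b, c) forwards only [a, c] to the
# executable, b having been dropped as an unused input earlier in lowering. When
# self.mut is set, each output paired with an out_mut index is written back into
# the matching mutable-array argument buffer rather than being returned.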
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
xla_pmap_p = core.MapPrimitive('xla_pmap')
|
|
|
|
|
xla_pmap = xla_pmap_p.bind
|
|
|
|
|
xla_pmap_p.def_impl(xla_pmap_impl)
|
|
|
|
|
|
|
|
|
|
def _pmap_partial_eval_custom_params_updater(
|
|
|
|
|
unks_in, inst_in, kept_outs_known, kept_outs_staged, num_res, params_known,
|
|
|
|
|
params_staged):
|
|
|
|
|
# prune inputs to jaxpr_known according to unks_in
|
|
|
|
|
donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
|
|
|
|
|
in_axes_known, _ = partition_list(unks_in, params_known['in_axes'])
|
|
|
|
|
_, out_axes_known = partition_list(kept_outs_known, params_known['out_axes'])
|
|
|
|
|
out_axes_known = out_axes_known + [0] * num_res
|
|
|
|
|
new_params_known = dict(params_known, in_axes=tuple(in_axes_known),
|
|
|
|
|
out_axes=tuple(out_axes_known),
|
|
|
|
|
donated_invars=tuple(donated_invars_known))
|
|
|
|
|
|
|
|
|
|
# added num_res new inputs to jaxpr_staged, pruning according to inst_in
|
|
|
|
|
_, donated_invars_staged = partition_list(inst_in, params_staged['donated_invars'])
|
|
|
|
|
donated_invars_staged = [False] * num_res + donated_invars_staged
|
|
|
|
|
_, in_axes_staged = partition_list(inst_in, params_staged['in_axes'])
|
|
|
|
|
in_axes_staged = [0] * num_res + in_axes_staged
|
|
|
|
|
_, out_axes_staged = partition_list(kept_outs_staged, params_staged['out_axes'])
|
|
|
|
|
new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
|
|
|
|
|
out_axes=tuple(out_axes_staged),
|
|
|
|
|
donated_invars=tuple(donated_invars_staged))
|
|
|
|
|
return new_params_known, new_params_staged
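# For reference (assumed util.partition_list semantics): partition_list splits a
# list by a boolean mask and returns (entries where the mask is False, entries
# where it is True), e.g.
#   partition_list([False, True, False], ['a', 'b', 'c']) == (['a', 'c'], ['b'])
# so donated_invars_known above keeps exactly the entries whose inputs are known
# (unks_in[i] is False).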
|
|
|
|
|
|
|
|
|
|
def _pmap_partial_eval_custom_res_maker(params_known, aval):
|
|
|
|
|
return core.unmapped_aval(params_known['axis_size'], core.no_axis_name, 0, aval)
|
|
|
|
|
|
|
|
|
|
def _pmap_dce_rule(used_outputs, eqn):
|
|
|
|
|
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
|
2024-03-04 05:41:29 -08:00
|
|
|
|
axis_name = eqn.params["axis_name"]
|
|
|
|
|
with maybe_extend_axis_env(axis_name, eqn.params["global_axis_size"], None):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
|
|
|
|
|
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
|
|
|
|
|
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
|
|
|
|
|
_, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
|
|
|
|
|
new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
|
|
|
|
|
donated_invars=tuple(donated_invars),
|
|
|
|
|
in_axes=tuple(in_axes), out_axes=tuple(out_axes))
|
|
|
|
|
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
|
|
|
|
|
return used_inputs, None
|
|
|
|
|
else:
|
2024-03-04 05:41:29 -08:00
|
|
|
|
effs = core.filter_named_axis_effects(new_jaxpr.effects, {axis_name})
|
2023-02-06 14:28:36 -08:00
|
|
|
|
new_eqn = pe.new_jaxpr_eqn(
|
|
|
|
|
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
|
|
|
|
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
2024-03-04 05:41:29 -08:00
|
|
|
|
eqn.primitive, new_params, effs, eqn.source_info)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
return used_inputs, new_eqn
|
|
|
|
|
|
|
|
|
|
|
2023-08-08 14:39:57 -07:00
|
|
|
|
def _xla_call_partial_eval_update_params(
|
|
|
|
|
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
|
|
|
|
|
) -> core.ParamDict:
|
|
|
|
|
donated_invars = params['donated_invars']
|
|
|
|
|
if not kept_inputs and donated_invars:
|
|
|
|
|
# JaxprTrace.post_process_call creates a call with no input tracers
|
|
|
|
|
donated_invars = (False,) * num_new_inputs
|
|
|
|
|
else:
|
|
|
|
|
assert len(kept_inputs) == len(donated_invars)
|
|
|
|
|
# JaxprTrace.process_call drops known input tracers
|
|
|
|
|
donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
|
|
|
|
|
# Any new inputs are prepended to the left, so mark those as not donated.
|
|
|
|
|
donated_invars = [False] * num_new_inputs + donated_invars
|
|
|
|
|
return dict(params, donated_invars=tuple(donated_invars))
|
|
|
|
|
|
|
|
|
|
def xla_call_jvp_update_params(params, nz_tangents):
|
|
|
|
|
donated_invars = params['donated_invars']
|
|
|
|
|
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
|
|
|
|
|
new_donated_invars = (*donated_invars, *donated_tangents)
|
|
|
|
|
return dict(params, donated_invars=new_donated_invars)
|
|
|
|
|
|
|
|
|
|
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
|
|
|
|
|
donated_invars = params['donated_invars']
|
|
|
|
|
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
|
|
|
|
|
donated_cotangents = [False for nz in nonzero_cts if nz]
|
|
|
|
|
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
|
|
|
|
|
|
|
|
|
|
|
2023-02-06 14:28:36 -08:00
|
|
|
|
# Set param update handlers to update `donated_invars` just like xla_call_p
|
2023-08-08 14:39:57 -07:00
|
|
|
|
pe.call_param_updaters[xla_pmap_p] = _xla_call_partial_eval_update_params
|
2023-02-06 14:28:36 -08:00
|
|
|
|
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
|
|
|
|
|
partial(pe.call_partial_eval_custom_rule,
|
|
|
|
|
'call_jaxpr', _pmap_partial_eval_custom_params_updater,
|
|
|
|
|
res_aval=_pmap_partial_eval_custom_res_maker)
|
|
|
|
|
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
|
2023-08-08 14:39:57 -07:00
|
|
|
|
ad.call_param_updaters[xla_pmap_p] = xla_call_jvp_update_params
|
|
|
|
|
ad.call_transpose_param_updaters[xla_pmap_p] = _xla_call_transpose_update_params
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
|
|
|
|
|
|
|
|
|
|
def _pmap_axis_subst(params, subst, traverse):
|
|
|
|
|
if 'call_jaxpr' not in params:
|
|
|
|
|
return params
|
|
|
|
|
if not traverse:
|
|
|
|
|
return params
|
|
|
|
|
def shadowed_subst(name):
|
|
|
|
|
return (name,) if name in params['axis_name'] else subst(name)
|
|
|
|
|
with maybe_extend_axis_env(params['axis_name'],
|
|
|
|
|
params['global_axis_size'], None):
|
|
|
|
|
new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'],
|
|
|
|
|
shadowed_subst)
|
|
|
|
|
return dict(params, call_jaxpr=new_jaxpr)
|
|
|
|
|
core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _unravel_index_hlo(axis_env):
|
|
|
|
|
div = mlir.ir_constant(
|
2023-02-28 12:40:30 -08:00
|
|
|
|
np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
|
2023-02-06 14:28:36 -08:00
|
|
|
|
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
|
2023-11-17 11:46:24 -08:00
|
|
|
|
return hlo.remainder(hlo.divide(hlo.replica_id(), div), mod)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
def _hlo_shard(aval, axis_env, xs, in_axis):
|
|
|
|
|
if aval is core.abstract_token:
|
|
|
|
|
return xs
|
|
|
|
|
elif isinstance(aval, core.ShapedArray):
|
|
|
|
|
x, = xs
|
|
|
|
|
dims = list(aval.shape)
|
|
|
|
|
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
|
|
|
|
idxs = [zero] * len(dims)
|
|
|
|
|
idxs.insert(in_axis, _unravel_index_hlo(axis_env))
|
|
|
|
|
dims_unsqueezed = dims.copy()
|
|
|
|
|
dims_unsqueezed.insert(in_axis, 1)
|
2023-11-17 11:46:24 -08:00
|
|
|
|
dynamic_slice_result = hlo.dynamic_slice(
|
2023-12-11 12:29:57 -08:00
|
|
|
|
x, idxs, mlir.dense_int_array(dims_unsqueezed))
|
2023-02-06 14:28:36 -08:00
|
|
|
|
return [
|
2023-11-17 11:46:24 -08:00
|
|
|
|
hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(aval)
|
|
|
|
|
|
|
|
|
|
|
2023-08-08 14:39:57 -07:00
|
|
|
|
def _axis_read(axis_env, axis_name):
|
|
|
|
|
try:
|
|
|
|
|
return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
|
|
|
|
|
except ValueError:
|
|
|
|
|
raise NameError(f"unbound axis name: {axis_name}") from None
|
|
|
|
|
|
|
|
|
|
def axis_groups(axis_env: sharding_impls.AxisEnv, name) -> tuple[tuple[int, ...]]:
|
|
|
|
|
if not isinstance(name, (list, tuple)):
|
|
|
|
|
name = (name,)
|
|
|
|
|
mesh_axes = tuple(unsafe_map(partial(_axis_read, axis_env), name))
|
|
|
|
|
trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
|
|
|
|
|
assert not ragged
|
|
|
|
|
mesh_spec = axis_env.sizes + (trailing_size,)
|
|
|
|
|
return _axis_groups(mesh_spec, mesh_axes)
|
|
|
|
|
|
|
|
|
|
def _axis_groups(mesh_spec, mesh_axes):
|
|
|
|
|
"""Computes replica group ids for a collective performed over a subset of the mesh.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
mesh_spec: A sequence of integers representing the mesh shape.
|
|
|
|
|
mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
|
|
|
|
|
indicating over which axes the collective is performed.
|
|
|
|
|
Returns:
|
|
|
|
|
A tuple of replica groups (i.e. tuples containing replica ids).
|
|
|
|
|
"""
|
|
|
|
|
iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
|
|
|
|
|
groups = np.reshape(
|
|
|
|
|
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
|
|
|
|
|
(math.prod(np.take(mesh_spec, mesh_axes)), -1))
|
|
|
|
|
return tuple(unsafe_map(tuple, groups.T))
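# Worked example (computed from the code above) for a 2x2 replica mesh, where
# iota is [[0, 1], [2, 3]]:
#   _axis_groups((2, 2), (0,)) == ((0, 2), (1, 3))   # collective over mesh axis 0
#   _axis_groups((2, 2), (1,)) == ((0, 1), (2, 3))   # collective over mesh axis 1
# Each returned tuple lists the replica ids that participate in one group.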
|
|
|
|
|
|
|
|
|
|
|
2023-02-06 14:28:36 -08:00
|
|
|
|
# TODO(b/110096942): more efficient gather
|
2023-10-12 13:32:47 -07:00
|
|
|
|
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
if aval is core.abstract_token:
|
|
|
|
|
return xs
|
|
|
|
|
elif isinstance(aval, core.ShapedArray):
|
|
|
|
|
x, = xs
|
|
|
|
|
dims = list(aval.shape)
|
|
|
|
|
padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
|
|
|
|
|
padded = mlir.full_like_aval(ctx, 0, padded_aval)
|
|
|
|
|
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
|
|
|
|
idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
|
2023-12-11 12:29:57 -08:00
|
|
|
|
broadcast_result = hlo.broadcast(x, mlir.dense_int_array([1]))
|
2023-11-17 11:46:24 -08:00
|
|
|
|
padded = hlo.dynamic_update_slice(padded, broadcast_result, idxs)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
replica_groups = mlir.dense_int_elements(
|
2023-08-08 14:39:57 -07:00
|
|
|
|
axis_groups(axis_env, axis_env.names[-1]))
|
2023-11-17 11:46:24 -08:00
|
|
|
|
out = hlo.cross_replica_sum(padded, replica_groups)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
if out_axis != 0:
|
|
|
|
|
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
|
|
|
|
|
perm = list(range(1, len(dims)))
|
|
|
|
|
perm.insert(out_axis, 0)
|
|
|
|
|
transposed_dims = list(dims)
|
|
|
|
|
transposed_dims.insert(out_axis, axis_env.sizes[-1])
|
2023-12-11 12:29:57 -08:00
|
|
|
|
out = hlo.transpose(out, mlir.dense_int_array(perm))
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(aval)
|
|
|
|
|
|
2023-08-08 14:39:57 -07:00
|
|
|
|
def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int):
|
|
|
|
|
return sharding_impls.AxisEnv(env.nreps, env.names + (name,),
|
|
|
|
|
env.sizes + (size,))
|
|
|
|
|
|
|
|
|
|
|
2023-02-06 14:28:36 -08:00
|
|
|
|
def _pmap_lowering(ctx, *in_nodes, axis_name,
|
|
|
|
|
axis_size, global_axis_size, devices, name,
|
|
|
|
|
call_jaxpr, backend=None, in_axes, out_axes,
|
2023-03-29 09:22:34 -07:00
|
|
|
|
donated_invars, is_explicit_global_axis_size):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
del donated_invars # Unused.
|
2023-10-25 10:39:47 -07:00
|
|
|
|
mlir.check_backend_matches(backend, ctx.module_context.platforms)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
# We in-line here rather than generating a Call HLO as in the xla_call
|
|
|
|
|
# translation rule just because the extra tuple stuff is a pain.
|
|
|
|
|
if ctx.module_context.axis_env.names and devices is not None:
|
|
|
|
|
raise ValueError("Nested pmap with explicit devices argument.")
|
2023-08-08 14:39:57 -07:00
|
|
|
|
new_env = _extend_axis_env(ctx.module_context.axis_env, axis_name,
|
|
|
|
|
global_axis_size)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
# Shard the in_nodes that are mapped
|
|
|
|
|
in_avals = [v.aval for v in call_jaxpr.invars]
|
|
|
|
|
in_nodes_sharded = (
|
|
|
|
|
_hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
|
|
|
|
|
if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
|
|
|
|
|
for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))
|
|
|
|
|
|
|
|
|
|
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
|
|
|
|
sub_ctx = ctx.module_context.replace(
|
2024-02-20 07:16:38 -08:00
|
|
|
|
axis_context=sharding_impls.ReplicaAxisContext(new_env))
|
|
|
|
|
sharded_outs, _ = mlir.jaxpr_subcomp(
|
|
|
|
|
sub_ctx, call_jaxpr,
|
|
|
|
|
ctx.name_stack.extend(util.wrap_name(name, 'pmap')),
|
|
|
|
|
mlir.TokenSet(), (), *in_nodes_sharded,
|
|
|
|
|
dim_var_values=ctx.dim_var_values)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
out_avals = [v.aval for v in call_jaxpr.outvars]
|
2023-10-12 13:32:47 -07:00
|
|
|
|
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
|
|
|
|
|
return outs
|
|
|
|
|
|
|
|
|
|
mlir.register_lowering(xla_pmap_p, _pmap_lowering)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------- xmap -------------------
|
|
|
|
|
|
|
|
|
|
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
|
|
|
|
|
assert isinstance(aval, ShapedArray)
|
|
|
|
|
shape = list(aval.shape)
|
|
|
|
|
named_shape = dict(aval.named_shape)
|
|
|
|
|
for name, axis in in_axes.items():
|
|
|
|
|
assert shape[axis] % axis_sizes[name] == 0
|
|
|
|
|
assert name not in named_shape
|
|
|
|
|
named_shape[name] = axis_sizes[name]
|
|
|
|
|
shape[axis] //= axis_sizes[name]
|
|
|
|
|
return aval.update(shape=tuple(shape), named_shape=named_shape)
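# Illustrative example (follows directly from the code above): tiling over a mesh
# axis 'x' of size 4 mapped to array dimension 0,
#   tile_aval_nd({'x': 4}, {'x': 0}, ShapedArray((8, 3), np.float32))
# yields a ShapedArray((2, 3), np.float32) whose named_shape gains {'x': 4}.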
|
|
|
|
|
|
|
|
|
|
def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
|
|
|
|
|
assert isinstance(aval, ShapedArray)
|
|
|
|
|
shape = list(aval.shape)
|
|
|
|
|
named_shape = dict(aval.named_shape)
|
|
|
|
|
for name, axis in out_axes.items():
|
|
|
|
|
shape[axis] *= axis_sizes[name]
|
|
|
|
|
named_shape.pop(name, None) # The name might be missing --- it's a broadcast.
|
|
|
|
|
return aval.update(shape=tuple(shape), named_shape=named_shape)
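# Conversely (illustrative): untile_aval_nd({'x': 4}, {'x': 0}, aval) multiplies
# dimension 0 back out by 4 and drops 'x' from the named_shape, inverting the
# tile_aval_nd example above.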
|
|
|
|
|
|
|
|
|
|
|
2023-03-10 10:07:37 -08:00
|
|
|
|
def mesh_local_to_global(mesh, axes: ArrayMapping, aval):
|
|
|
|
|
return untile_aval_nd(mesh.shape, axes,
|
|
|
|
|
tile_aval_nd(mesh.local_mesh.shape, axes, aval))
|
|
|
|
|
|
|
|
|
|
def mesh_global_to_local(mesh, axes: ArrayMapping, aval):
|
|
|
|
|
return untile_aval_nd(mesh.local_mesh.shape, axes,
|
|
|
|
|
tile_aval_nd(mesh.shape, axes, aval))
|
|
|
|
|
|
|
|
|
|
|
2023-02-06 14:28:36 -08:00
|
|
|
|
class SPMDBatchTrace(batching.BatchTrace):
|
|
|
|
|
def get_axis_primitive_batcher(self, primitive, frame):
|
|
|
|
|
if primitive in spmd_primitive_batchers:
|
|
|
|
|
return partial(spmd_primitive_batchers[primitive],
|
|
|
|
|
frame.size, frame.name, frame.main_trace.trace_type)
|
|
|
|
|
return super().get_axis_primitive_batcher(primitive, frame)
|
|
|
|
|
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
|
spmd_primitive_batchers: dict[core.Primitive, Callable] = {}
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vtile_by_mesh(fun: lu.WrappedFun,
|
|
|
|
|
mesh: Mesh,
|
|
|
|
|
in_axes: Sequence[ArrayMapping],
|
|
|
|
|
out_axes: Sequence[ArrayMapping]):
|
|
|
|
|
# We vectorize in reversed order, because vmap is often biased towards
|
|
|
|
|
# moving the batch axis to the front, and this way of stacking transforms
|
|
|
|
|
# will order the batch axes according to the mesh axis order.
|
|
|
|
|
# Not strictly necessary, but seems nicer than reversing it?
|
|
|
|
|
for name, size in reversed(mesh.shape.items()):
|
|
|
|
|
fun = batching.vtile(fun,
|
|
|
|
|
tuple(a.get(name, None) for a in in_axes),
|
|
|
|
|
tuple(a.get(name, None) for a in out_axes),
|
|
|
|
|
tile_size=size,
|
|
|
|
|
axis_name=name,
|
|
|
|
|
main_type=SPMDBatchTrace)
|
|
|
|
|
return fun
|
|
|
|
|
|
|
|
|
|
full_to_shard_p = core.Primitive('full_to_shard')
|
|
|
|
|
|
|
|
|
|
@full_to_shard_p.def_abstract_eval
|
|
|
|
|
def _full_to_shard_abstract_eval(x, axes, mesh, **_):
|
|
|
|
|
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
|
|
|
|
|
return tile_aval_nd(mesh.shape, axes, x)
|
|
|
|
|
|
2023-04-10 10:15:08 -07:00
|
|
|
|
def manual_proto(
|
|
|
|
|
aval: core.ShapedArray,
|
2023-06-23 15:11:37 -07:00
|
|
|
|
manual_axes_set: frozenset[sharding_impls.MeshAxisName], mesh: Mesh):
|
2023-02-06 14:28:36 -08:00
|
|
|
|
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
|
|
|
|
|
and all others as replicated.
|
|
|
|
|
"""
|
|
|
|
|
named_mesh_shape = mesh.shape
|
|
|
|
|
mesh_shape = list(named_mesh_shape.values())
|
|
|
|
|
axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}
|
|
|
|
|
|
2023-11-14 23:34:30 -05:00
|
|
|
|
manual_axes = sorted(manual_axes_set, key=str)
|
|
|
|
|
replicated_axes = [axis for axis in mesh.axis_names
|
|
|
|
|
if axis not in manual_axes_set]
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
|
|
|
|
tad_perm = ([axis_order[a] for a in replicated_axes] +
|
|
|
|
|
[axis_order[a] for a in manual_axes])
|
|
|
|
|
tad_shape = [1] * aval.ndim
|
2023-04-13 11:48:11 -07:00
|
|
|
|
tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
|
|
|
|
|
tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))
|
2023-02-06 14:28:36 -08:00
|
|
|
|
|
2023-04-13 11:48:11 -07:00
|
|
|
|
raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
|
2023-02-06 14:28:36 -08:00
|
|
|
|
proto = xc.OpSharding()
|
|
|
|
|
proto.type = xc.OpSharding.Type.OTHER
|
|
|
|
|
proto.tile_assignment_dimensions = tad_shape
|
|
|
|
|
proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat)
|
|
|
|
|
proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
|
|
|
|
|
return proto


@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
                            manual_axes: frozenset[sharding_impls.MeshAxisName]):
  # TODO: Can we short-circuit for replicated values? Probably not.
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  sharding_proto = (
      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
      ._to_xla_hlo_sharding(aval_in.ndim).to_proto())
  unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto,
                                  unspecified_dims=unspecified_dims)
  proto = manual_proto(aval_in, manual_axes, mesh)
  return (mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto,
                                          unspecified_dims=unspecified_dims),)


shard_to_full_p = core.Primitive('shard_to_full')

@shard_to_full_p.def_abstract_eval
def _shard_to_full_abstract_eval(x, axes, mesh, **_):
  # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
  return untile_aval_nd(mesh.shape, axes, x)


@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping, mesh: Mesh,
                            manual_axes: frozenset[sharding_impls.MeshAxisName]):
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  proto = manual_proto(aval_in, manual_axes, mesh)  # type: ignore
  unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())  # type: ignore
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto,
                                  unspecified_dims=unspecified_dims)
  sharding_proto = (
      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
      ._to_xla_hlo_sharding(aval_out.ndim).to_proto())
  return (mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto,
                                          unspecified_dims),)


@lu.transformation
def vtile_manual(manual_axes: frozenset[sharding_impls.MeshAxisName],
                 mesh: Mesh,
                 in_axes: Sequence[ArrayMapping],
                 out_axes: Sequence[ArrayMapping],
                 *args):
  tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes)
                for arg, axes in zip(args, in_axes)]
  tiled_outs = yield tiled_args, {}
  outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes)
          for out, axes in zip(tiled_outs, out_axes)]
  yield outs
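
# vtile_manual is the manual-sharding counterpart of vtile_by_mesh: it binds
# full_to_shard on every argument before yielding to the traced computation, so
# the body sees per-shard shapes, and binds shard_to_full on every output to
# recover the global shapes.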


@dataclasses.dataclass(frozen=True)
class TileVectorize:
  pass

@dataclasses.dataclass(frozen=True)
class TileManual:
  manual_axes: frozenset[sharding_impls.MeshAxisName]


TilingMethod = Union[TileVectorize, TileManual]


def check_if_any_auto(
    shardings: Iterable[(sharding_impls.XLACompatibleSharding |
                         AUTO | UnspecifiedValue)]) -> bool:
  for s in shardings:
    if is_auto(s):
      return True
  return False


class MismatchType(enum.Enum):
  ARG_SHARDING = 0
  OUT_SHARDING = 1
  SHARDING_INSIDE_COMPUTATION = 2
  CONTEXT_DEVICES = 3
  IN_SHARDING = 4

  def __str__(self):
    if self.name == 'IN_SHARDING':
      return 'explicit input sharding'
    elif self.name == 'OUT_SHARDING':
      return 'explicit output sharding'
    elif self.name == 'CONTEXT_DEVICES':
      return 'devices'
    return f'{self.name}'


@dataclasses.dataclass
class DeviceAssignmentMismatch:
  da: Sequence[xc.Device]
  m_type: MismatchType
  source_info: dispatch.SourceInfo | None

  @property
  def device_ids(self) -> Sequence[int]:
    return [d.id for d in self.da]

  @property
  def platform(self) -> str:
    return self.da[0].platform.upper()

  def _maybe_api_name(self, api_name) -> str:
    return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else ""

  @property
  def source_info_str(self):
    return (
        "" if self.source_info is None
        else f" at {source_info_util.summarize(self.source_info.source_info)}"
    )

  @property
  def _dev_ids_plat_str(self):
    return f"device ids {self.device_ids} on platform {self.platform}"

  def m_type_str(self, api_name):
    return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}'
            if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type)

  def _str(self, api_name):
    return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with "
            f"{self._dev_ids_plat_str}{self.source_info_str}")


class DeviceAssignmentMismatchError(Exception):
  pass


ShardingInfo = tuple[
    Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
    MismatchType,
    Union[Any, None],  # Any is dispatch.SourceInfo to avoid circular imports
]


def _get_default_device() -> xc.Device:
  return config.default_device.value or xb.local_devices()[0]


class _thread_local_decorator(threading.local):

  def __init__(self, fn):
    self.fn = fn

  def __call__(self, *args, **kwargs):
    return self.fn(*args, **kwargs)


@_thread_local_decorator
def _get_and_check_device_assignment(
    shardings: Iterable[ShardingInfo],
    devices: Sequence[xc.Device] | None,
) -> tuple[xc.Client, tuple[xc.Device, ...]]:
  first_sharding_info = None
  if devices is None:
    devices = ()
  else:
    devices = tuple(devices)

  for i, s_type, source_info in shardings:
    if is_unspecified(i):
      continue

    if first_sharding_info is None:
      first_sharding_info = (
          (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i)  # type: ignore
          else (i._device_assignment, s_type, source_info))  # type: ignore
    arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment  # type: ignore
    if not devices:
      if first_sharding_info[0] != arr_device_assignment:
        raise DeviceAssignmentMismatchError([
            DeviceAssignmentMismatch(*first_sharding_info),
            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
    else:
      if devices != arr_device_assignment:
        raise DeviceAssignmentMismatchError([
            DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
  if first_sharding_info is None and devices:
    final_device_assignment = devices
  elif first_sharding_info is None:
    final_device_assignment = (_get_default_device(),)
  else:
    final_device_assignment = first_sharding_info[0]
  return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
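
# Roughly, the resolution order above is: explicit `devices` from the caller's
# context win, otherwise the device assignment of the first concrete
# (non-UNSPECIFIED) sharding is used, and if neither exists the single default
# device is chosen; any disagreement between these sources raises
# DeviceAssignmentMismatchError.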


MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]


def prune_unused_inputs(
    jaxpr: core.Jaxpr,
) -> tuple[core.Jaxpr, set[int], set[int]]:
  used_outputs = [True] * len(jaxpr.outvars)
  new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
  kept_const_idx = {i for i, b in enumerate(used_consts) if b}
  kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
  return new_jaxpr, kept_const_idx, kept_var_idx
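
# For example, if a jaxpr has three invars and dead-code elimination finds that
# only the first and third contribute to any output, used_inputs is
# [True, False, True] and kept_var_idx is {0, 2}.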


@weakref_lru_cache
def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
               keep_unused, donated_invars, auto_spmd_lowering):
  name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

  assert isinstance(closed_jaxpr, core.ClosedJaxpr)
  jaxpr = closed_jaxpr.jaxpr
  global_out_avals = closed_jaxpr.out_avals
  consts = closed_jaxpr.consts

  if (keep_unused or auto_spmd_lowering or
      any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
          for a in global_in_avals)):
    kept_var_idx = set(range(len(global_in_avals)))
  else:
    jaxpr, kept_const_idx, kept_var_idx = prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
    donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
    del kept_const_idx

  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
          kept_var_idx, name_stack)


class MutationData(NamedTuple):
  in_mut: list[core.MutableArray]
  out_mut: list[int | None]


@weakref_lru_cache
def _discharge_refs(
    jaxpr: core.ClosedJaxpr
) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]:
  from jax._src.state.discharge import discharge_state
  jaxpr, in_mut = _move_mutable_consts(jaxpr)
  new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
  count = it.count(len(jaxpr.out_avals))  # new outputs are appended to the end
  inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
               if isinstance(a, AbstractRef)}
  outin_map = {j: i for i, j in inout_map.items()}
  inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
  out_mut = list(map(outin_map.get, range(len(new_jaxpr.out_avals))))
  return new_jaxpr, inout_aliases, MutationData(in_mut, out_mut)
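
# A rough illustration, assuming discharge_state appends one value output per
# input ref: if the jaxpr takes in_avals = [Ref(f32[3]), f32[3]] and originally
# had two outputs, then inout_map is {0: 2}, inout_aliases is (2, None) (input 0
# aliases the appended output 2), and out_mut is [None, None, 0].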


@weakref_lru_cache
def _move_mutable_consts(
    closed_jaxpr: core.ClosedJaxpr,
) -> tuple[core.ClosedJaxpr, list[core.MutableArray]]:
  jaxpr = closed_jaxpr.jaxpr
  hoist = [isinstance(c, core.MutableArray) for c in closed_jaxpr.consts]
  consts, in_mut = partition_list(hoist, closed_jaxpr.consts)
  constvars, mutvars = partition_list(hoist, jaxpr.constvars)
  invars = (*jaxpr.invars, *mutvars)
  effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
  jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
                     effects, None)
  return core.ClosedJaxpr(jaxpr, consts), in_mut


@dataclasses.dataclass(frozen=True)
class SemanticallyEqualShardings:
  shardings: tuple[sharding_impls.GSPMDSharding | UnspecifiedValue, ...]

  def __hash__(self):
    return hash(tuple(
        (s._hlo_sharding_hash, s.memory_kind)  # type: ignore
        if isinstance(s, sharding_impls.GSPMDSharding) else s
        for s in self.shardings))

  def __eq__(self, other):
    if not isinstance(other, SemanticallyEqualShardings):
      return False
    return all(
        (op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
         and s.memory_kind == o.memory_kind)
        if (isinstance(s, sharding_impls.GSPMDSharding) and
            isinstance(o, sharding_impls.GSPMDSharding))
        else s == o
        for s, o in zip(self.shardings, other.shardings)
    )
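
# This wrapper lets tuples of shardings serve as cache keys (e.g. for the
# weakref_lru_cache on _cached_lowering_to_hlo below): two GSPMDShardings that
# describe the same HLO sharding and memory kind compare equal even when they
# are distinct Python objects.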


def _raise_warnings_or_errors_for_jit_of_pmap(
    nreps: int, backend: xc.Client, name: str, jaxpr: core.Jaxpr) -> None:
  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (
      nreps > 1 or dispatch.jaxpr_has_primitive(jaxpr, "xla_pmap")
  ):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")


@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                            semantic_in_shardings, semantic_out_shardings,
                            in_layouts, out_layouts, num_devices, device_assignment,
                            donated_invars, name_stack, all_default_mem_kind,
                            inout_aliases: None | tuple[None | int, ...],
                            lowering_parameters: mlir.LoweringParameters):
  jaxpr = closed_jaxpr.jaxpr
  in_shardings = semantic_in_shardings.shardings
  out_shardings = semantic_out_shardings.shardings
  global_in_avals = closed_jaxpr.in_avals
  global_out_avals = closed_jaxpr.out_avals

  log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority,
               "Compiling %s with global shapes and types %s. "
               "Argument mapping: %s.",
               fun_name, global_in_avals, in_shardings)

  # Look at the number of replicas present in the jaxpr. In
  # lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
  # handled here so as to deprecate the lower_xla_callable codepath when
  # `jax.Array` is turned on by default.
  # TODO(yashkatariya): Remove this when `jit(pmap)` is removed.
  nreps = dispatch.jaxpr_replicas(jaxpr)
  _raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, fun_name, jaxpr)

  in_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
  out_mlir_shardings: list[sharding_impls.XLACompatibleSharding | None] | None
  axis_ctx: mlir.AxisContext

  if nreps == 1:
    in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
    out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
    replicated_args = [False] * len(global_in_avals)
    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
    num_partitions = num_devices
  else:
    # This path is triggered for `jit(pmap)` cases.
    replicated_args = None
    in_mlir_shardings = None
    out_mlir_shardings = None
    axis_env = sharding_impls.AxisEnv(nreps, (), ())
    axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
    num_partitions = 1

  module_name = f"{api_name}_{fun_name}"

  if num_devices > 1:
    unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects)
    unsupported_effects = effects.shardable_ordered_effects.filter_not_in(
        unsupported_effects)
    if len(unsupported_effects) > 0:
      raise ValueError(
          "The following ordered effects are not supported for "
          f"more than 1 device: {unsupported_effects}")
  ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))

  with dispatch.log_elapsed_time(
      "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
      fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
    lowering_result = mlir.lower_jaxpr_to_module(
        module_name,
        closed_jaxpr,
        ordered_effects=ordered_effects,
        backend_or_name=backend,
        # Optionally, override the lowering platform
        platforms=lowering_parameters.platforms or (backend.platform,),
        axis_context=axis_ctx,
        name_stack=name_stack,
        donated_args=donated_invars,
        replicated_args=replicated_args,
        arg_shardings=in_mlir_shardings,
        result_shardings=out_mlir_shardings,
        in_layouts=in_layouts,
        out_layouts=out_layouts,
        arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
        num_replicas=nreps,
        num_partitions=num_partitions,
        all_default_mem_kind=all_default_mem_kind,
        input_output_aliases=inout_aliases,
        lowering_parameters=lowering_parameters)
  tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
  unordered_effects = list(
      effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
  return (lowering_result.module, lowering_result.keepalive,
          lowering_result.host_callbacks, unordered_effects, ordered_effects,
          nreps, tuple_args, lowering_result.shape_poly_state)


@lru_cache(maxsize=2048)
def _create_da_object(  # pytype: disable=invalid-annotation
    device_assignment: tuple[xc.Device, ...]) -> xc.DeviceList:  # type: ignore
  return xc.DeviceList(device_assignment)


def jaxpr_transfer_mem_kinds(
    jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
  for eqn in jaxpr.eqns:
    if (eqn.primitive is dispatch.device_put_p and
        isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
      yield eqn.params['device']
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_transfer_mem_kinds(subjaxpr)


def are_all_shardings_default_mem_kind(da_object, shardings):
  try:
    default_mem_kind = da_object.default_memory_kind
  except:
    return True
  for i in shardings:
    if is_unspecified_or_auto(i):
      continue
    if i.memory_kind != default_mem_kind:
      return False
  return True


MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]


class AllArgsInfo(NamedTuple):
  """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
  in_avals: Sequence[core.ShapedArray]
  in_shardings: Any
  debug_info: core.JaxprDebugInfo | None


@profiler.annotate_function
def lower_sharding_computation(
    closed_jaxpr: core.ClosedJaxpr,
    api_name: str,
    fun_name: str,
    in_shardings: Sequence[MaybeSharding],
    out_shardings: Sequence[MaybeSharding],
    donated_invars: Sequence[bool],
    global_in_avals: Sequence[core.ShapedArray],
    *,
    keep_unused: bool,
    inline: bool,
    devices_from_context: Sequence[xc.Device] | None = None,
    lowering_parameters: mlir.LoweringParameters,
    in_layouts: MaybeLayout,
    out_layouts: MaybeLayout,
) -> MeshComputation:
  """Lowers a computation to XLA. It can take arbitrary shardings as input.

  The caller of this code can pass in a singleton UNSPECIFIED because the
  number of out_avals might not be known at that time and
  lower_sharding_computation calculates the number of out_avals so it can apply
  the singleton UNSPECIFIED to all out_avals.
  """
  # 1. Trace to jaxpr and preprocess/verify it
  auto_spmd_lowering = check_if_any_auto(
      it.chain.from_iterable([in_shardings, out_shardings]))  # type: ignore

  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
                              closed_jaxpr.jaxpr.debug_info)

  (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
   kept_var_idx, name_stack) = _dce_jaxpr(
      closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
      donated_invars, auto_spmd_lowering)
  in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
  in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)

  if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
    closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
    in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
    in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
    donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
    out_layouts_ = iter(zip(out_shardings, out_layouts))
    out_shardings, out_layouts = unzip2(
        next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
        for i in mut.out_mut)
    assert next(out_layouts_, None) is None
    # TODO(yashkatariya): remove global_in_avals / global_out_avals
    global_in_avals = closed_jaxpr.in_avals
    global_out_avals = closed_jaxpr.out_avals
  else:
    inout_aliases = mut = None

  jaxpr = closed_jaxpr.jaxpr
  assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
      len(out_shardings), len(out_layouts), len(global_out_avals))

  # Device assignment across all inputs, outputs and shardings inside jaxpr
  # should be the same.
  jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
  backend, device_assignment = _get_and_check_device_assignment(
      it.chain(
          ((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
          ((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
          ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
           for js, source_info in util.stable_unique(jaxpr_sharding))),
      devices_from_context)

  # TODO(yashkatariya): Enable this when offload APIs are stable.
  # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))

  committed = bool(
      devices_from_context or
      len(device_assignment) > 1 or
      any(not is_unspecified(i) for i in in_shardings) or
      any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
      any(not is_unspecified(o) for o in out_shardings))

  gs = GSPMDSharding.get_replicated(device_assignment)
  if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
    in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

  da_object = _create_da_object(tuple(device_assignment))

  all_default_mem_kind = are_all_shardings_default_mem_kind(
      da_object,
      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding]))  # type: ignore

  if not da_object.is_fully_addressable:  # type: ignore
    if inline and config.spmd_mode.value != 'allow_all':
      raise RuntimeError(
          "Running operations on `Array`s that are not fully addressable by this "
          "process (i.e. `Array`s with data sharded across multiple devices and "
          "processes) is dangerous. It’s very important that all processes run "
          "the same cross-process computations in the same order otherwise it "
          "can lead to hangs. "
          "If you’re not already familiar with JAX’s multi-process "
          "programming model, please read "
          "https://jax.readthedocs.io/en/latest/multi_process.html. "
          "To fix this error, run your `jitted` computation inside "
          "`with jax.spmd_mode('allow_all'):` context manager.")

  # 2. Build up the HLO
  semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
  semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)

  (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
   nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
      closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
      semantic_out_shardings, in_layouts, out_layouts, len(da_object),
      tuple(da_object) if prim_requires_devices else None, donated_invars,
      name_stack, all_default_mem_kind, inout_aliases,
      lowering_parameters=lowering_parameters)

  # backend and device_assignment is passed through to MeshExecutable because
  # if keep_unused=False and all in_shardings are pruned, then there is no way
  # to get the device_assignment and backend. So pass it to MeshExecutable
  # because we calculate the device_assignment and backend before in_shardings,
  # etc are pruned.
  return MeshComputation(
      str(name_stack),
      module,
      donated_invars,
      global_in_avals=global_in_avals,
      global_out_avals=global_out_avals,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      spmd_lowering=True,
      tuple_args=tuple_args,
      auto_spmd_lowering=auto_spmd_lowering,
      unordered_effects=unordered_effects,
      ordered_effects=ordered_effects,
      host_callbacks=host_callbacks,
      keepalive=keepalive,
      kept_var_idx=kept_var_idx,
      mut=mut,
      backend=backend,
      device_assignment=da_object,
      committed=committed,
      in_layouts=in_layouts,
      out_layouts=out_layouts,
      pmap_nreps=nreps,
      shape_poly_state=shape_poly_state,
      all_default_mem_kind=all_default_mem_kind,
      all_args_info=all_args_info)


def _to_logical_sharding(
    aval: core.AbstractValue, sharding: MaybeSharding | AUTO
) -> sharding_impls.XLACompatibleSharding | None:
  if is_unspecified(sharding) or is_auto(sharding):
    return None
  elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)):
    assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
    return sharding
  elif isinstance(aval, core.AbstractToken):
    return None
  else:
    raise TypeError(aval)


@profiler.annotate_function
def lower_mesh_computation(
    fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr,
    api_name: str,
    fun_name: str,
    mesh: Mesh,
    in_shardings: Sequence[sharding_impls.NamedSharding | AUTO],
    out_shardings: Sequence[(sharding_impls.NamedSharding | AUTO |
                             UnspecifiedValue)],
    donated_invars: Sequence[bool],
    spmd_lowering: bool,
    global_in_avals: Sequence[core.ShapedArray],
    tiling_method: TilingMethod | None,
    lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
  assert not mesh.empty
  backend = xb.get_device_backend(mesh.devices.flat[0])
  name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

  global_axis_sizes = mesh.shape

  log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority,
               "Compiling %s for %s mesh with global shapes and types %s. "
               "Argument mapping: %s.",
               fun_name, tuple(global_axis_sizes.items()), global_in_avals,
               in_shardings)

  # 1. Trace to jaxpr and preprocess/verify it
  if spmd_lowering:
    manual_axes: frozenset[MeshAxisName] = frozenset()
    # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
    if tiling_method is not None:
      if isinstance(tiling_method, TileVectorize):
        tiling_transform = vtile_by_mesh
      elif isinstance(tiling_method, TileManual):
        tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args)  # type: ignore
        manual_axes = tiling_method.manual_axes
      else:
        raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
      assert not callable(out_shardings)
      assert isinstance(fun_or_jaxpr, lu.WrappedFun)
      # This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
      # is why `.spec` can be accessed.
      fun_or_jaxpr = tiling_transform(
          fun_or_jaxpr, mesh, [get_array_mapping(i.spec) for i in in_shardings],  # type: ignore
          [get_array_mapping(o.spec) for o in out_shardings])  # type: ignore
    in_jaxpr_avals = global_in_avals
  else:
    assert isinstance(tiling_method, TileVectorize)
    # In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
    # why `.spec` can be accessed.
    in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval)  # type: ignore
                      for aval, i in safe_zip(global_in_avals, in_shardings)]
    in_jaxpr_avals = in_tiled_avals

  with core.extend_axis_env_nd(mesh.shape.items()):
    if isinstance(fun_or_jaxpr, lu.WrappedFun):
      with dispatch.log_elapsed_time(
          "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
          fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
        jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(
            fun_or_jaxpr, in_jaxpr_avals)
    else:
      assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
      jaxpr = fun_or_jaxpr.jaxpr
      out_jaxpr_avals = fun_or_jaxpr.out_avals
      consts = fun_or_jaxpr.consts

  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)

  assert len(out_shardings) == len(out_jaxpr_avals)
  if spmd_lowering:
    global_out_avals = out_jaxpr_avals
  else:
    # In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
    # why `.spec` can be accessed.
    global_out_avals = [untile_aval_nd(global_axis_sizes, get_array_mapping(o.spec), aval)  # type: ignore
                        for aval, o in safe_zip(out_jaxpr_avals, out_shardings)]

  _sanitize_mesh_jaxpr(jaxpr)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  # 2. Build up the HLO
  tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)

  in_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
  out_partitions: list[sharding_impls.XLACompatibleSharding | None] | None
  axis_ctx: mlir.AxisContext
  if spmd_lowering:
    in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
    out_partitions = map(_to_logical_sharding, global_out_avals, out_shardings)
    replicated_args = [False] * len(in_jaxpr_avals)
    axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
    num_replicas = 1
    num_partitions = mesh.devices.size
  else:
    replicated_args = [not get_array_mapping(i.spec) for i in in_shardings]  # type: ignore
    in_partitions = None
    out_partitions = None
    axis_env = sharding_impls.AxisEnv(
        nreps=mesh.size,
        names=tuple(global_axis_sizes.keys()),
        sizes=tuple(global_axis_sizes.values()))
    axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
    num_replicas = mesh.devices.size
    num_partitions = 1
  jaxpr = core.remove_named_axis_effects(jaxpr, mesh.axis_names)
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module_name = f"{api_name}_{fun_name}"
  with core.extend_axis_env_nd(mesh.shape.items()):
    if any(effects.ordered_effects.contains(eff) for eff
           in closed_jaxpr.effects):
      raise ValueError("Ordered effects not supported in mesh computations.")
    unordered_effects = list(effects.ordered_effects.filter_not_in(
        closed_jaxpr.effects))
    ordered_effects = list(effects.ordered_effects.filter_in(
        closed_jaxpr.effects))
    with dispatch.log_elapsed_time(
        "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
        fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
      lowering_result = mlir.lower_jaxpr_to_module(
          module_name,
          closed_jaxpr,
          ordered_effects=ordered_effects,
          backend_or_name=backend,
          platforms=lowering_parameters.platforms or (backend.platform,),
          axis_context=axis_ctx,
          name_stack=name_stack,
          donated_args=donated_invars,
          replicated_args=replicated_args,
          arg_shardings=in_partitions,
          result_shardings=out_partitions,
          arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
          result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
          num_replicas=num_replicas,
          num_partitions=num_partitions,
          lowering_parameters=lowering_parameters)

  return MeshComputation(
      str(name_stack),
      lowering_result.module,
      donated_invars,
      global_in_avals=global_in_avals,
      global_out_avals=global_out_avals,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      spmd_lowering=spmd_lowering,
      tuple_args=tuple_args,
      auto_spmd_lowering=False,
      unordered_effects=unordered_effects,
      ordered_effects=ordered_effects,
      host_callbacks=lowering_result.host_callbacks,
      keepalive=lowering_result.keepalive,
      kept_var_idx=set(range(len(global_in_avals))),
      backend=backend,
      device_assignment=_create_da_object(tuple(mesh.devices.flat)),
      committed=True,
      in_layouts=(None,) * len(global_in_avals),
      out_layouts=(None,) * len(global_out_avals),
      shape_poly_state=lowering_result.shape_poly_state,
      all_args_info=all_args_info)


class MeshComputation(stages.XlaLowering):
  _hlo: ir.Module | None
  _executable: MeshExecutable | None

  def __init__(self, name: str, hlo: ir.Module | None,
               donated_invars: Sequence[bool], **compile_args):
    self._name = name
    self._hlo = hlo
    self._donated_invars = donated_invars
    self.compile_args = compile_args
    self._executable = None

  # -- stages.XlaLowering overrides

  def stablehlo(self) -> ir.Module:
    return self._hlo

  def compile(self, compiler_options=None) -> MeshExecutable:
    if self._executable is None or compiler_options is not None:
      executable = UnloadedMeshExecutable.from_hlo(
          self._name, self._hlo, **self.compile_args,
          compiler_options=compiler_options)
      if compiler_options is None:
        self._executable = executable
      return executable
    return self._executable
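
  # Note: repeated `compile()` calls with default options reuse the cached
  # executable, while a call that passes `compiler_options` recompiles and
  # leaves the cache untouched, e.g. (hypothetical usage):
  #   exe1 = computation.compile()
  #   exe2 = computation.compile()                          # cached, same object
  #   exe3 = computation.compile(compiler_options={...})    # fresh compile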
|
|
|
|
|
|
2023-06-23 15:11:37 -07:00
|
|
|
|
def cost_analysis(self) -> dict[str, float]:
|
2023-02-15 01:49:55 +00:00
|
|
|
|
backend = self.compile_args["backend"]
|
|
|
|
|
if xb.using_pjrt_c_api(backend):
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"Lowered.cost_analysis not implemented on platform "
|
          f"'{backend.platform}'. Use compile().cost_analysis() for "  # type: ignore
          "post-compilation cost estimates.")
    return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())


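# A hedged usage sketch of how MeshComputation.compile above is typically
# reached; `f` and `x` are illustrative placeholders, not names from this
# module:
#   lowered = jax.jit(f).lower(x)      # wraps a MeshComputation in stages.Lowered
#   executable = lowered.compile()     # dispatches to MeshComputation.compile
#   executable.cost_analysis()         # post-compilation cost estimates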
if xla_extension_version < 229:
  def _get_input_indices(
      avals: Sequence[ShapedArray],
      shardings: Sequence[sharding_impls.XLACompatibleSharding],
      da_object: xc.DeviceList | Sequence[xc.Device],  # type: ignore
  ) -> Sequence[tuple[Index | None, ...]]:

    input_indices = []
    if not isinstance(da_object, xc.DeviceList):
      da_object = _create_da_object(tuple(da_object))
    num_addressable_devices = len(da_object.addressable_device_list)

    def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
      if ndim is None:
        return ((slice(None),),) * num_addressable_devices
      else:
        return ((slice(None),) * ndim,) * num_addressable_devices

    for aval, sharding in zip(avals, shardings):
      if aval is core.abstract_token:
        index = _get_replicated_slices(num_addressable_devices, None)
      else:
        if sharding.is_fully_replicated:
          index = _get_replicated_slices(num_addressable_devices, aval.ndim)
        else:
          index = tuple(
              sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
      input_indices.append(index)

    return input_indices


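# Worked example for the replicated-index helper above (the values follow
# directly from _get_replicated_slices as defined inside _get_input_indices):
#   _get_replicated_slices(2, ndim=1) == ((slice(None),), (slice(None),))
# i.e. each addressable device is assigned the full, unsliced array.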
def get_out_shardings_from_executable(
    xla_executable,
    device_assignment: Sequence[xc.Device],
    num_out_avals: int,
    num_ordered_effects: int,
    all_default_mem_kind: bool,
) -> Sequence[sharding_impls.GSPMDSharding] | None:
  from jax._src import pjit

  if config.enable_memories.value:
    if all_default_mem_kind:
      omk = [None] * num_out_avals
    else:
      try:
        omk = xla_executable.get_output_memory_kinds()[0]
        if num_ordered_effects > 0:
          omk = omk[num_ordered_effects:]
      except:
        omk = [None] * num_out_avals
  else:
    omk = [None] * num_out_avals

  assert len(omk) == num_out_avals, (len(omk), num_out_avals)

  # When the device assignment only has 1 device, SPMD partitioner will not run.
  # Hence the op shardings will not be set on the `hlo_module`.
  if len(device_assignment) == 1:
    return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk)
            for mk in omk]

  _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
  if not out_op_shardings:
    return None

  if num_ordered_effects > 0:
    out_op_shardings = out_op_shardings[num_ordered_effects:]

  # This means that there are no outputs for JAX but for XLA there is an empty
  # tuple output which gets a replicated sharding.
  if num_out_avals == 0 and len(out_op_shardings) == 1:
    return None

  # This condition happens when all the elements in the output tuple have the
  # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
  # put the sharding on ROOT instead of the tuple.
  # TODO(b/245667823): Remove this when XLA fixes this.
  if len(out_op_shardings) == 1 and len(out_op_shardings) < num_out_avals:
    out_op_shardings = out_op_shardings * num_out_avals  # type: ignore

  assert len(out_op_shardings) == num_out_avals == len(omk), (
      len(out_op_shardings), num_out_avals, len(omk))

  return [sharding_impls.GSPMDSharding(device_assignment, os, memory_kind=mk)
          for os, mk in safe_zip(out_op_shardings, omk)]


def _get_in_shardings_from_xla(
    xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
    num_ordered_effects: int
) -> Sequence[GSPMDSharding] | None:
  """Returns input shardings from XLA."""
  from jax._src import pjit

  # When the device assignment only has 1 device, SPMD partitioner will not run.
  # Hence the op shardings will not be set on the `hlo_module`.
  if len(device_assignment) == 1:
    return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals

  in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
  if not in_op_shardings:
    return None

  if num_ordered_effects > 0:
    in_op_shardings = in_op_shardings[num_ordered_effects:]

  assert len(in_op_shardings) == num_in_avals, (
      len(in_op_shardings), num_in_avals)

  return [GSPMDSharding(device_assignment, os)
          for os in in_op_shardings]


# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
    xla_executable, mesh: Mesh
) -> tuple[Sequence[sharding_impls.NamedSharding],
           Sequence[sharding_impls.NamedSharding]]:
  from jax._src import pjit

  in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh)
  return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec],
          [sharding_impls.NamedSharding(mesh, o) for o in out_pspec])


_orig_out_sharding_handlers = {}

_ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)


def _register_out_sharding_handler(
    sharding_cls: type[_ShardingT],
    handler: Callable[[sharding_impls.GSPMDSharding, _ShardingT], _ShardingT],
) -> None:
  _orig_out_sharding_handlers[sharding_cls] = handler


def _gspmd_to_named_sharding_via_mesh(
    out_s: sharding_impls.GSPMDSharding,
    mesh: Mesh) -> sharding_impls.NamedSharding:
  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
      out_s._hlo_sharding, mesh)[0]
  return create_mesh_pspec_sharding(
      mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
      out_s.memory_kind)


def _gspmd_to_named_sharding(
    out_s: sharding_impls.GSPMDSharding,
    orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
  return _gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

_register_out_sharding_handler(
    sharding_impls.NamedSharding, _gspmd_to_named_sharding)


def _gspmd_to_positional_sharding(
    out_s: sharding_impls.GSPMDSharding,
    orig_in_s: sharding_impls.PositionalSharding
) -> sharding_impls.PositionalSharding:
  return sharding_impls._op_sharding_to_pos_sharding(
      out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)

_register_out_sharding_handler(
    sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)


def _gspmd_to_single_device_sharding(
    out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
  assert isinstance(orig_in_s, SingleDeviceSharding)
  return SingleDeviceSharding(
      out_s._device_assignment[0], memory_kind=out_s.memory_kind)

_register_out_sharding_handler(
    SingleDeviceSharding, _gspmd_to_single_device_sharding)


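# Note: the handlers registered above are used by the functions below to
# convert a GSPMDSharding recovered from an XLA executable back into the
# user-facing sharding type (NamedSharding, PositionalSharding, or
# SingleDeviceSharding) that was originally provided, whenever possible.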
def _get_out_sharding_from_orig_sharding(
    out_shardings, out_avals, orig_in_s, orig_aval):
  out = []
  orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
  for o, out_aval in safe_zip(out_shardings, out_avals):
    if isinstance(o, sharding_impls.GSPMDSharding):
      try:
        # Only return the same input sharding object if the OpShardings and
        # in_aval.ndim and out_aval.ndim match. This is because a replicated
        # OpSharding does not encode the ndim. The devices will be the same
        # at this point because those checks happen before.
        if (orig_aval is not None and out_aval is not None and
            out_aval.ndim == orig_aval.ndim
            and sharding_impls.are_op_shardings_equal(
                o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
            and o.memory_kind == orig_in_s.memory_kind):
          out.append(orig_in_s)
        else:
          out.append(orig_handler(o, orig_in_s))
      except:
        out.append(o)
    else:
      out.append(o)
  return out


def maybe_get_orig_out_sharding(
    in_shardings, out_shardings, in_avals, out_avals):
  if all(hasattr(o, '_original_sharding') for o in out_shardings):
    return [o._original_sharding for o in out_shardings]

  orig_in_s = None
  orig_aval = None
  for i, aval in safe_zip(in_shardings, in_avals):
    oi = getattr(i, '_original_sharding', None)
    if type(oi) in _orig_out_sharding_handlers:
      orig_in_s = oi
      orig_aval = aval
      break
  if orig_in_s is not None:
    return _get_out_sharding_from_orig_sharding(
        out_shardings, out_avals, orig_in_s, orig_aval)

  return out_shardings


def _get_layouts_from_executable(
    xla_executable, in_layouts, out_layouts, num_ordered_effects
) -> tuple[Sequence[SpecifiedLayout | None], Sequence[SpecifiedLayout | None]]:
  try:
    in_layouts_xla = xla_executable.get_parameter_layouts()
    out_layouts_xla = xla_executable.get_output_layouts()
  except:
    return (None,) * len(in_layouts), (None,) * len(out_layouts)

  if num_ordered_effects > 0:
    in_layouts_xla = in_layouts_xla[num_ordered_effects:]
    out_layouts_xla = out_layouts_xla[num_ordered_effects:]

  new_in_layouts = []
  for x, i in safe_zip(in_layouts_xla, in_layouts):
    x = SpecifiedLayout(x)
    if isinstance(i, SpecifiedLayout):
      if i != x:
        raise AssertionError(
            f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
      new_in_layouts.append(i)
    else:
      new_in_layouts.append(x)

  new_out_layouts = []
  for x, o in safe_zip(out_layouts_xla, out_layouts):
    x = SpecifiedLayout(x)
    if isinstance(o, SpecifiedLayout):
      if o != x:
        raise AssertionError(
            f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
      new_out_layouts.append(o)
    else:
      new_out_layouts.append(x)

  assert all(isinstance(i, SpecifiedLayout) for i in new_in_layouts)
  assert all(isinstance(o, SpecifiedLayout) for o in new_out_layouts)
  return new_in_layouts, new_out_layouts  # type: ignore


def get_logical_mesh_ids(mesh_shape):
  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)


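# For example, worked out from the definition above:
#   get_logical_mesh_ids((2, 3)) -> array([[0, 1, 2],
#                                          [3, 4, 5]])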
@weakref_lru_cache
def _cached_compilation(computation, name, mesh, spmd_lowering,
                        tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
                        allow_prop_to_outputs, host_callbacks, backend,
                        da, pmap_nreps, compiler_options_keys,
                        compiler_options_values):
  # TODO(phawkins): One would normally just write:
  # dev = np.array(device_assignment)
  # The formulation below is substantially faster if there are many devices.
  # If we were to optimize __getattr__ on xc.Device we might not need this
  # workaround.
  dev = np.vectorize(lambda i: da[i], otypes=[object])(
      np.arange(len(da))
  )
  if pmap_nreps > 1:
    num_replicas, num_partitions = pmap_nreps, 1
  elif spmd_lowering:
    num_replicas, num_partitions = 1, dev.size
  else:
    num_replicas, num_partitions = dev.size, 1

  if pmap_nreps > 1:
    # In `jit` device_assignment is set to None when num_replicas > 1. Do
    # the same thing here too.
    xla_device_assignment = None
  else:
    xla_device_assignment = dev.reshape((num_replicas, num_partitions))

  if compiler_options_keys is None:
    compiler_options = None
  else:
    compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))

  fdo_profile = (None if compiler_options is None else
                 compiler_options.pop("fdo_profile", None))

  compile_options = compiler.get_compile_options(
      num_replicas=num_replicas,
      num_partitions=num_partitions,
      device_assignment=xla_device_assignment,
      use_spmd_partitioning=spmd_lowering,
      use_auto_spmd_partitioning=auto_spmd_lowering,
      env_options_overrides=compiler_options,
      fdo_profile=fdo_profile,
      detailed_logging=compiler.use_detailed_logging(computation),
      backend=backend,
  )

  opts = compile_options.executable_build_options
  if auto_spmd_lowering:
    assert mesh is not None
    opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
    opts.auto_spmd_partitioning_mesh_ids = (
        get_logical_mesh_ids(list(mesh.shape.values()))
        .reshape(-1))
  compile_options.parameter_is_tupled_arguments = tuple_args
  if xla_extension_version >= 241:
    opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
  opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)

  if hasattr(backend, "compile_replicated"):
    return None, compile_options

  with dispatch.log_elapsed_time(
      "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
      fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
    xla_executable = compiler.compile_or_get_cached(
        backend, computation, dev, compile_options, host_callbacks)
  return xla_executable, compile_options


def _maybe_get_and_check_in_shardings(
    xla_executable, in_shardings, device_assignment,
    global_in_avals, num_ordered_effects):
  """Returns in_shardings extracted from XLA or checks and returns original
  shardings.

  If in_shardings exist on `jit` or on `jax.Array`, then this function will
  check that sharding against what XLA returns as in_shardings. If they don't
  match, an error is raised.

  If in_sharding is unspecified, then the sharding returned by XLA is returned.
  """
  in_shardings_xla = _get_in_shardings_from_xla(  # type: ignore
      xla_executable, device_assignment, len(global_in_avals),
      num_ordered_effects)  # type: ignore
  if in_shardings_xla is None:
    return in_shardings

  new_in_shardings = []
  for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
                                    global_in_avals):
    if is_unspecified(orig):
      if (aval is not core.abstract_token and
          dtypes.issubdtype(aval.dtype, dtypes.extended)):
        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
      new_in_shardings.append(xla_s)
    else:
      # TODO(yashkatariya): Remove the if branch for abstract_token once
      # choosing input shardings by XLA is enabled again.
      if aval is core.abstract_token:
        new_in_shardings.append(orig)
      else:
        xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
        orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
        # MANUAL HloSharding comes from other partitioning frameworks.
        if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
            not xla_hlo_s.is_manual() and
            (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s))):
          raise AssertionError(
              f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
              "(User sharding)")
        new_in_shardings.append(orig)
  return new_in_shardings


def _maybe_get_and_check_out_shardings(
    xla_executable, out_shardings, device_assignment, global_out_avals,
    num_ordered_effects, all_default_mem_kind
    ):
  out_shardings_xla = get_out_shardings_from_executable(  # type: ignore
      xla_executable, device_assignment, len(global_out_avals),
      num_ordered_effects, all_default_mem_kind)  # type: ignore
  if out_shardings_xla is None:
    return out_shardings

  new_out_shardings = []
  for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                    global_out_avals):
    if is_unspecified(orig):
      if (aval is not core.abstract_token and
          dtypes.issubdtype(aval.dtype, dtypes.extended)):
        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
      new_out_shardings.append(xla_s)
    else:
      xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
      orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
      # MANUAL HloSharding comes from other partitioning frameworks.
      if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
          not xla_hlo_s.is_manual() and
          (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
           xla_s.memory_kind != orig.memory_kind)):  # type: ignore
        raise AssertionError(
            f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
            "(User sharding)")
      new_out_shardings.append(orig)
  return new_out_shardings


def finalize_out_shardings(out_shardings, device_assignment):
  if len(device_assignment) == 1:
    return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
            if isinstance(o, GSPMDSharding) else o for o in out_shardings]
  return out_shardings


@dataclasses.dataclass
class UnloadedMeshExecutable:
  xla_executable: Any
  device_assignment: xc.DeviceList | Sequence[xc.Device]  # type: ignore
  backend: xb.XlaBackend
  input_avals: Sequence[ShapedArray]
  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  output_avals: Sequence[ShapedArray]
  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  committed: bool
  name: str
  unordered_effects: list[core.Effect]
  ordered_effects: list[core.Effect]
  keepalive: Sequence[Any]
  host_callbacks: Sequence[Any]
  kept_var_idx: set[int]
  mut: MutationData | None
  auto_spmd_lowering: bool
  in_layouts: Sequence[SpecifiedLayout | None]
  out_layouts: Sequence[SpecifiedLayout | None]
  all_args_info: AllArgsInfo | None

  def build_unsafe_call(self):
    if xla_extension_version >= 229:
      handle_args = InputsHandler(self.input_shardings)
    else:
      input_indices = _get_input_indices(self.input_avals, self.input_shardings,
                                         self.device_assignment)
      handle_args = InputsHandler(
          self.input_shardings, self.xla_executable.local_devices(), input_indices)
    handle_outs = global_avals_to_results_handler(
        self.output_avals, self.output_shardings, self.committed)  # type: ignore  # arg-type

    unsafe_call = ExecuteReplicated(  # type: ignore  # assignment
        self.xla_executable, self.name, self.backend, handle_args,
        handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
        bool(self.host_callbacks), self.kept_var_idx, self.mut)
    return unsafe_call

  def load(self) -> MeshExecutable:
    return MeshExecutable(self.xla_executable, self.build_unsafe_call,
                          self.input_avals, self.output_avals,
                          self.input_shardings, self.output_shardings,
                          self.auto_spmd_lowering, self.kept_var_idx,
                          self.in_layouts, self.out_layouts,
                          self.all_args_info, self)

  # May return a MeshExecutable in the compile_replicated case.
  @staticmethod
  def from_hlo(name: str,
               hlo: ir.Module,
               global_in_avals: Sequence[ShapedArray],
               global_out_avals: Sequence[ShapedArray],
               in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
               out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
                                        UnspecifiedValue)],
               spmd_lowering: bool,
               tuple_args: bool,
               auto_spmd_lowering: bool,
               unordered_effects: list[core.Effect],
               ordered_effects: list[core.Effect],
               host_callbacks: list[Any],
               keepalive: Any,
               kept_var_idx: set[int],
               backend: xb.XlaBackend,
               device_assignment: xc.DeviceList | Sequence[xc.Device],  # type: ignore
               committed: bool,
               in_layouts: MaybeLayout,
               out_layouts: MaybeLayout,
               pmap_nreps: int = 1,
               mut: MutationData | None = None,
               shape_poly_state: mlir.ShapePolyLoweringState | None = None,
               all_default_mem_kind: bool = True,
               all_args_info: AllArgsInfo | None = None,
               compiler_options=None,
               ) -> MeshExecutable:
    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
      hlo = mlir.refine_polymorphic_shapes(hlo)
    compiler_options_keys = tuple(
        compiler_options.keys()) if compiler_options is not None else None
    compiler_options_values = tuple(
        compiler_options.values()) if compiler_options is not None else None
    if isinstance(device_assignment, xc.DeviceList):
      da = device_assignment
    else:
      da = _create_da_object(tuple(device_assignment))
    del device_assignment

    allow_prop_to_inputs = tuple(is_unspecified(i) for i in in_shardings)
    allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)

    mesh = None
    if auto_spmd_lowering:
      for i in it.chain.from_iterable([in_shardings, out_shardings]):
        if is_auto(i):
          mesh = i.mesh  # type: ignore
          break

    xla_executable, compile_options = _cached_compilation(
        hlo, name, mesh, spmd_lowering,
        tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
        allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
        compiler_options_keys, compiler_options_values)

    if hasattr(backend, "compile_replicated"):
      semantics_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
      semantics_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
      return _compile_replicated_mesh_executable_from_hlo(
          hlo, name, tuple(global_in_avals), tuple(global_out_avals),
          semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
          compile_options, tuple(host_callbacks), bool(unordered_effects),
          tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
          pmap_nreps)

    if auto_spmd_lowering:
      assert mesh is not None
      in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
          xla_executable, mesh)
      in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i)  # type: ignore
                      for x, i in safe_zip(in_shardings_xla, in_shardings)]
      out_shardings = [x if is_auto(o) else o
                       for x, o in safe_zip(out_shardings_xla, out_shardings)]
    else:
      if pmap_nreps == 1:
        assert mesh is None
        if xla_extension_version >= 241:
          in_shardings = _maybe_get_and_check_in_shardings(
              xla_executable, in_shardings, tuple(da), global_in_avals,
              len(ordered_effects))
        out_shardings = _maybe_get_and_check_out_shardings(
            xla_executable, out_shardings, tuple(da), global_out_avals,
            len(ordered_effects), all_default_mem_kind)
      else:
        in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
            xla_executable.local_devices(), len(in_shardings), len(out_shardings))

    if xla_extension_version >= 217:
      in_layouts, out_layouts = _get_layouts_from_executable(
          xla_executable, in_layouts, out_layouts, len(ordered_effects))
    else:
      assert all(i is None for i in in_layouts)
      assert all(o is None for o in out_layouts)

    out_shardings = maybe_get_orig_out_sharding(
        in_shardings, out_shardings, global_in_avals, global_out_avals)

    out_shardings = finalize_out_shardings(out_shardings, da)

    return UnloadedMeshExecutable(
        xla_executable=xla_executable,
        device_assignment=da,  # type: ignore
        backend=backend,
        input_avals=global_in_avals,
        input_shardings=in_shardings,  # type: ignore
        output_avals=global_out_avals,
        output_shardings=out_shardings,  # type: ignore  # arg-type
        committed=committed,
        name=name,
        unordered_effects=unordered_effects,
        ordered_effects=ordered_effects,
        keepalive=keepalive,
        host_callbacks=host_callbacks,
        kept_var_idx=kept_var_idx,
        mut=mut,
        auto_spmd_lowering=auto_spmd_lowering,
        in_layouts=in_layouts,  # type: ignore
        out_layouts=out_layouts,  # type: ignore
        all_args_info=all_args_info).load()  # type: ignore


class MeshExecutableFastpathData(NamedTuple):
  xla_executable: xc.LoadedExecutable
  out_pytree_def: Any
  in_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  out_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  out_avals: Sequence[ShapedArray]
  out_committed: Sequence[bool]
  kept_var_bitvec: Iterable[bool]
  # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
  arg_handler_devices: Sequence[xc.Device]
  arg_handler_indices: Sequence[tuple[Index | None, ...]]


def reflatten_outputs_for_dispatch(out_tree, out_flat):
  # We arrive at dispatch having flattened according to the default
  # pytree registry, but we want to re-flatten according to our
  # dispatch-specific registry.
  out_unflat = tree_util.tree_unflatten(out_tree, out_flat)
  return tree_util.dispatch_registry.flatten(out_unflat, None)


class MeshExecutable(stages.XlaExecutable):
  __slots__ = [
      "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
      "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
      "_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
      "_unloaded_executable",
  ]

  def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
               in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
               in_layouts, out_layouts,
               all_args_info: AllArgsInfo | None = None,
               unloaded_executable=None):
    self.xla_executable = xla_executable
    self.build_unsafe_call = build_unsafe_call
    # in_avals is a list of global and local avals. Aval is global if input
    # is a GDA or jax.Array else local.
    self.in_avals = in_avals
    self.out_avals = out_avals
    self._unsafe_call = None
    self._in_shardings = in_shardings
    self._out_shardings = out_shardings
    self._auto_spmd_lowering = auto_spmd_lowering
    self._kept_var_idx = kept_var_idx
    self._in_layouts = in_layouts
    self._out_layouts = out_layouts
    self._all_args_info = all_args_info
    self._unloaded_executable = unloaded_executable

  @property
  def unsafe_call(self) -> Callable[..., Any]:
    if self._unsafe_call is None:
      self._unsafe_call = self.build_unsafe_call()
    return self._unsafe_call  # type: ignore

  # -- stages.XlaExecutable overrides

  def xla_extension_executable(self):
    return self.xla_executable

  def call(self, *args):
    if self._all_args_info is None:
      kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
      ref_avals = self.in_avals
      in_shardings = self._in_shardings
      debug_info = None
    else:
      kept_args = args
      ref_avals = self._all_args_info.in_avals
      iter_in_shardings = iter(self._in_shardings)
      in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
                      for i, s in enumerate(self._all_args_info.in_shardings)]
      debug_info = self._all_args_info.debug_info

    arg_avals = map(xla.abstractify, kept_args)
    check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
    # Check the GDA sharding and the input sharding.
    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
    return self.unsafe_call(*args)  # pylint: disable=not-callable

  def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
    return self._in_shardings

  def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
    return self._out_shardings

  def input_layouts(self):
    return self._in_layouts

  def output_layouts(self):
    return self._out_layouts

  def create_cpp_call(self, no_kwargs, in_tree, out_tree):
    if not (isinstance(self.unsafe_call, ExecuteReplicated) and
            not self.unsafe_call.has_unordered_effects and
            not self.unsafe_call.has_host_callbacks):
      return None

    def aot_cache_miss(*args, **kwargs):
      params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
      outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
      out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
          out_tree, out_flat)
      use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))

      if use_fastpath:
        out_avals = [o.aval for o in out_flat]
        out_committed = [o._committed for o in out_flat]
        kept_var_bitvec = [i in self._kept_var_idx
                           for i in range(len(args_flat))]
        in_shardings = [
            a.dtype._rules.physical_sharding(a, s)
            if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
            else s
            for s, a in zip(self._in_shardings, self.in_avals)
        ]
        fastpath_data = MeshExecutableFastpathData(
            self.xla_executable, out_tree_dispatch, in_shardings,
            self._out_shardings, out_avals, out_committed, kept_var_bitvec,
            self.unsafe_call.in_handler.local_devices,
            self.unsafe_call.in_handler.input_indices)
      else:
        fastpath_data = None
      return outs, fastpath_data

    if xla_extension_version >= 226:
      return xc._xla.pjit(
          self.unsafe_call.name, None, aot_cache_miss, [], [], [],
          tree_util.dispatch_registry,
          shard_arg if xla_extension_version >= 229 else temp_shard_arg)  # type: ignore
    else:
      return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],  # type: ignore
                          tree_util.dispatch_registry)


# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
def temp_shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
  return shard_arg(arg, sharding)


def check_arg_avals_for_call(ref_avals, arg_avals,
                             jaxpr_debug_info: core.JaxprDebugInfo | None = None):
  if len(ref_avals) != len(arg_avals):
    raise TypeError(
        f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")

  if jaxpr_debug_info is not None:
    arg_names = [f"'{name}'" for name in jaxpr_debug_info.arg_names]
  else:
    num_args = len(ref_avals)
    arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)]

  errors = []
  for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
    if not core.typematch(ref_aval, arg_aval):
      errors.append(
          f"Argument {name} compiled with {ref_aval.str_short()} and called "
          f"with {arg_aval.str_short()}")
  if errors:
    max_num_errors = 5
    str_errors = "\n".join(errors[:max_num_errors])
    if len(errors) >= max_num_errors:
      num_mismatch_str = f"The first {max_num_errors} of {len(errors)}"
    else:
      num_mismatch_str = "The"
    raise TypeError(
        "Argument types differ from the types for which this computation was "
        f"compiled. {num_mismatch_str} mismatches are:\n{str_errors}")


def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
  # Create replicated shardings for jit(pmap) path with local devices
  # because multihost jit(pmap) is not allowed.
  gs = sharding_impls.GSPMDSharding.get_replicated(local_devices)
  in_shardings = [gs] * num_in_shardings
  out_shardings = [gs] * num_out_shardings
  # jit(pmap) will generate Arrays with multi-device sharding.
  # It is unsupported for these shardings to be uncommitted, so force
  # the outputs to be committed.
  committed = True
  return in_shardings, out_shardings, committed, tuple(local_devices)


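# Sketch of the jit(pmap) metadata above: with two local devices, three inputs
# and one output, this yields three replicated GSPMDShardings for the inputs,
# one for the output, committed=True, and the devices as a tuple.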
@weakref_lru_cache
def _compile_replicated_mesh_executable_from_hlo(
    computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
    semantics_out_shardings, auto_spmd_lowering, compile_options,
    host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
    backend, da, committed, pmap_nreps):
  assert not auto_spmd_lowering
  in_shardings = semantics_in_shardings.shardings
  out_shardings = semantics_out_shardings.shardings

  kept_var_idx = set(kept_var_idx)
  # Will compute out_handler with executable information.
  unsafe_call = backend.compile_replicated(
      is_trivial=False, name=name, computation=computation,
      compile_options=compile_options, host_callbacks=host_callbacks,
      has_unordered_effects=has_unordered_effects,
      device_assignment=da, ordered_effects=ordered_effects,
      in_avals=global_in_avals,
      in_shardings=in_shardings, kept_var_idx=kept_var_idx,
      out_avals=global_out_avals, out_shardings=out_shardings,
      committed=committed, pmap_nreps=pmap_nreps)
  xla_executable = None
  return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                        global_out_avals, in_shardings, out_shardings,
                        auto_spmd_lowering, kept_var_idx,
                        (None,) * len(global_in_avals),
                        (None,) * len(global_out_avals))


@lru_cache
def create_mesh_pspec_sharding(
    mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
    memory_kind: str | None = None) -> sharding_impls.NamedSharding:
  if pspec is None:
    pspec, parsed_pspec = PartitionSpec(), None
  return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
                                      memory_kind=memory_kind)


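# Hedged usage sketch for the helper above; the mesh construction and axis
# names are illustrative assumptions, not part of this module:
#   devices = np.array(jax.devices()).reshape(2, -1)
#   mesh = jax.sharding.Mesh(devices, ('x', 'y'))
#   s = create_mesh_pspec_sharding(mesh, PartitionSpec('x', None))
#   # s behaves like NamedSharding(mesh, PartitionSpec('x', None))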
def check_device_backend_on_shardings(shardings) -> bool:
  for i in shardings:
    if is_unspecified(i) or is_auto(i):
      continue
    if hasattr(i, '_original_sharding') and getattr(
        i._original_sharding, '_device_backend', False):
      return True
  return False


def check_gda_or_array_xla_sharding_match(
    args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
    jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
  from jax._src.array import ArrayImpl
  arg_names = ([''] * len(args) if jaxpr_debug_info is None else
               jaxpr_debug_info.arg_names)
  errors = []
  num_errors = 5
  for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
    if not isinstance(arg, ArrayImpl):
      continue
    if is_unspecified_or_auto(xs):
      continue

    db_xs = check_device_backend_on_shardings([xs])
    if not db_xs:
      xs = getattr(xs, '_original_sharding', xs)

    # Raise memory kind mismatch error even if the arg is uncommitted.
    if arg.sharding.memory_kind != xs.memory_kind:
      errors.append(
          "Got input sharding(s) that compiled object was called with: "
          f"{arg.sharding} and sharding(s) the computation was compiled "
          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")

    if (not db_xs and arg._committed and
        not op_shardings.are_op_shardings_equal(
            arg.sharding._to_xla_hlo_sharding(arg.ndim),
            xs._to_xla_hlo_sharding(arg.ndim))):
      errors.append(
          "Got input sharding(s) that compiled object was called with: "
          f"{arg.sharding} and sharding(s) the computation was compiled "
          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")

  if errors:
    str_errors = '\n'.join(errors[:num_errors])
    num_mismatch_str = (
        f'the {len(errors)} mismatches' if len(errors) < num_errors else
        f"{num_errors} mismatches out of {len(errors)}")
    raise ValueError(
        "Compiled object called with input sharding(s) does not match the "
        "sharding(s) the computation was compiled with. "
        f"Here are {num_mismatch_str}:\n{str_errors}")


def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
  parsed_pspec, _, _ = sharding_impls.prepare_axis_resources(
      pspec, "pspec to array_mapping")
  return _get_array_mapping(parsed_pspec)


_forbidden_primitives = {
    'xla_pmap': 'pmap',
}


def _sanitize_mesh_jaxpr(jaxpr):
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  for eqn in jaxpr.eqns:
    if eqn.primitive.name in _forbidden_primitives:
      raise RuntimeError(f"Nesting {_forbidden_primitives[eqn.primitive.name]} "
                         f"inside xmaps not supported!")
    core.traverse_jaxpr_params(_sanitize_mesh_jaxpr, eqn.params)


custom_resource_typing_rules: dict[core.Primitive, Callable] = {}


def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr

  def _check_aval(aval, what_thunk):
    if not hasattr(aval, 'named_shape'):
      return
    resource_to_axis = {}
    for axis in aval.named_shape:
      if axis_resources:
        for resource in axis_resources[axis]:
          if resource in resource_to_axis:
            other_axis = resource_to_axis[resource]
            axis, other_axis = sorted([str(axis), str(other_axis)])
            raise JAXTypeError(
                f"Axes `{axis}` and `{other_axis}` are both mapped to the "
                f"resource `{resource}`, but they coincide in the named_shape "
                f"of {what_thunk()}")
          resource_to_axis[resource] = axis

  what_thunk = lambda: (f"an input to {what_jaxpr_thunk()}")
  for v in jaxpr.constvars:
    _check_aval(v.aval, what_thunk)
  for v in jaxpr.invars:
    _check_aval(v.aval, what_thunk)
  what_thunk = lambda: (f"a value returned from a primitive {eqn.primitive} created "
                        f"at {source_info_util.summarize(eqn.source_info)}")
  rec_what_jaxpr_thunk = lambda: (f"a primitive {eqn.primitive} created at"
                                  f"{source_info_util.summarize(eqn.source_info)}")
  for eqn in jaxpr.eqns:
    typing_rule = custom_resource_typing_rules.get(eqn.primitive, None)
    if typing_rule:
      typing_rule([v.aval for v in eqn.invars], eqn.params, eqn.source_info,
                  resource_env, axis_resources)
    else:
      core.traverse_jaxpr_params(partial(resource_typecheck,
                                         resource_env=resource_env,
                                         axis_resources=axis_resources,
                                         what_jaxpr_thunk=rec_what_jaxpr_thunk),
                                 eqn.params)
    for v in eqn.outvars:
      _check_aval(v.aval, what_thunk)


@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
  with core.extend_axis_env(*args, **kwargs):
    yield


def device_put(x, devices: Sequence[xc.ArrayImpl],
               replicate: bool = False) -> list[xc.ArrayImpl]:
  """Call device_put on a sequence of devices and return a flat sequence of buffers."""
  if replicate:
    return [jax.device_put(x, device) for device in devices]
  else:
    return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]
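

# Hedged usage sketch for device_put above; the array and device choice are
# illustrative assumptions:
#   bufs = device_put(np.zeros(4), jax.local_devices()[:1], replicate=True)
#   # -> a one-element list containing the array placed on the first local device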
|