# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import numpy as np
import itertools as it
from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
                    NamedTuple, Union, Sequence, Mapping)
from functools import wraps, partial, partialmethod, lru_cache

from jax import lax
from jax import numpy as jnp

from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
                               _ensure_index_tuple, donation_vector,
                               shaped_abstractify, check_callable)
from jax._src.array import ArrayImpl
from jax._src.config import config
from jax._src.errors import JAXTypeError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters.partial_eval import (
    trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
    convert_constvars_jaxpr, new_jaxpr_eqn)
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.pjit import (sharding_constraint_p, get_unconstrained_dims,
                           GSPMDSharding)
from jax._src.sharding_impls import (
    ArrayMapping, NamedSharding, ParsedPartitionSpec,
    array_mapping_to_axis_resources)
from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
                                tree_map, treedef_tuple)
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
                           as_hashable_function, distributed_debug_log,
                           tuple_insert, moveaxis, split_list, wrap_name,
                           merge_lists, partition_list)

source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

map, unsafe_map = safe_map, map
zip = safe_zip


class FrozenDict(abc.Mapping):
  def __init__(self, *args, **kwargs):
    self.contents = dict(*args, **kwargs)

  def __iter__(self):
    return iter(self.contents)

  def __len__(self):
    return len(self.contents)

  def __getitem__(self, name):
    return self.contents[name]

  def __eq__(self, other):
    return isinstance(other, FrozenDict) and self.contents == other.contents

  def __hash__(self):
    # NOTE: frozenset makes the hash independent of insertion order, matching
    # the order-insensitive dict equality used in __eq__ above.
    return hash(frozenset(self.contents.items()))

  def __repr__(self):
    return f"FrozenDict({self.contents})"

# Multi-dimensional generalized map

AxisName = core.AxisName
ResourceAxisName = mesh_lib.ResourceAxisName  # Different name just for documentation purposes
Mesh = mesh_lib.Mesh
MeshAxisName = mesh_lib.MeshAxisName
ResourceEnv = mesh_lib.ResourceEnv
EMPTY_ENV = mesh_lib.EMPTY_ENV
thread_resources = mesh_lib.thread_resources


class SerialLoop:
  """Create an anonymous serial loop resource for use in a single xmap axis.

  A use of :py:class:`SerialLoop` in :py:func:`xmap`'s ``axis_resources``
  extends the resource environment with a new serial loop with a unique,
  unspecified name that will only be used to partition the axis that
  used a given instance.

  This is unlike :py:func:`serial_loop`, which makes it possible to iterate
  jointly over chunks of multiple axes (with the usual requirement that they
  do not coincide in a named shape of any value in the program).

  Example::

    # Processes `x` in a vectorized way, but in 20 micro-batches.
    xmap(f, in_axes=['i'], out_axes=['i'], axis_resources={'i': SerialLoop(20)})(x)

    # Computes the result in a vectorized way, but in 400 micro-batches,
    # once for each coordinate (0, 0) <= (i, j) < (20, 20). Each `SerialLoop`
    # creates a fresh anonymous loop.
    xmap(h, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
         axis_resources={'i': SerialLoop(20), 'j': SerialLoop(20)})(x, y)
  """
  length: int

  def __init__(self, length):
    self.length = length

  def __eq__(self, other):
    return isinstance(other, SerialLoop) and self.length == other.length

  def __hash__(self):
    return hash(self.length)


@contextlib.contextmanager
def serial_loop(name: ResourceAxisName, length: int):
  """Define a serial loop resource to be available in scope of this context manager.

  This is similar to :py:class:`Mesh` in that it extends the resource
  environment with a resource called ``name``. But, any use of this resource
  axis in the ``axis_resources`` argument of :py:func:`xmap` will cause the
  body of :py:func:`xmap` to get executed ``length`` times, with each execution
  processing only a slice of inputs mapped along logical axes assigned
  to this resource.

  This is especially useful in that it makes it possible to lower the memory
  usage compared to :py:func:`vmap`, because it will avoid simultaneous
  materialization of intermediate values for every point in the batch.

  Note that collectives over loop axes are not supported, so they are less
  versatile than physical mesh axes.

  Args:
    name: Name of the loop in the resource environment.
    length: Number of iterations.

  Example::

    with serial_loop('l', 4):
      out = xmap(
        lambda x: jnp.sin(x) * 5,  # This will be called 4 times with different
                                   # slices of x.
        in_axes=['i'], out_axes=['i'],
        axis_resources={'i': 'l'})(x)
  """
  old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
  thread_resources.env = old_env.with_extra_loop(mesh_lib.Loop(name, length))
  try:
    yield
  finally:
    thread_resources.env = old_env


_next_resource_id = 0
class _UniqueResourceName:
  def __init__(self, uid, tag=None):
    self.uid = uid
    self.tag = tag
  def __eq__(self, other):
    return type(other) is _UniqueResourceName and self.uid == other.uid
  def __hash__(self):
    return hash(self.uid)
  def __repr__(self):
    return f"<UniqueResource {self.tag} {self.uid}>"

def fresh_resource_name(tag=None):
  global _next_resource_id
  try:
    return _UniqueResourceName(_next_resource_id, tag)
  finally:
    _next_resource_id += 1
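
# Example (illustrative): every call mints a globally unique resource name, so
# anonymous SerialLoop instances never alias each other:
#
#   r1, r2 = fresh_resource_name('i'), fresh_resource_name('i')
#   assert r1 != r2 and repr(r1) != repr(r2)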


# This is really a Dict[AxisName, int], but we don't define a
# pytree instance for it, so that it is treated as a leaf.
class AxisNamePos(FrozenDict):
  user_repr: str
  expected_rank: Optional[int] = None

  def __init__(self, *args, user_repr, **kwargs):
    super().__init__(*args, **kwargs)
    self.user_repr = user_repr

class AxisNamePosWithRank(AxisNamePos):
  def __init__(self, *args, expected_rank, **kwargs):
    super().__init__(*args, **kwargs)
    self.expected_rank = expected_rank


# str(...) == 'Ellipsis' which is really annoying
class DotDotDotRepr:
  def __repr__(self): return '...'


def _parse_entry(arg_name, entry):
  # Dictionaries mapping positional axes to axis names
  if isinstance(entry, dict) and all(isinstance(v, int) for v in entry.keys()):
    result = AxisNamePos(((name, axis) for axis, name in entry.items()),
                         user_repr=str(entry))
    num_mapped_dims = len(entry)
  # Non-empty lists or tuples that optionally terminate with an ellipsis
  elif isinstance(entry, (tuple, list)):
    if entry and entry[-1] == ...:
      constr = AxisNamePos
      entry = entry[:-1]
      tail = [DotDotDotRepr()] if isinstance(entry, list) else (DotDotDotRepr(),)
      user_repr = str(entry + tail)
    else:
      constr = partial(AxisNamePosWithRank, expected_rank=len(entry))
      user_repr = str(entry)
    result = constr(((name, axis) for axis, name in enumerate(entry)
                     if name is not None),
                    user_repr=user_repr)
    num_mapped_dims = sum(name is not None for name in entry)
  else:
    raise TypeError(f"""\
Value mapping specification in xmap {arg_name} pytree can be either:
- lists of axis names (possibly ending with the ellipsis object: ...)
- dictionaries that map positional axes (integers) to axis names (e.g. {{2: 'name'}})
but got: {entry}""")
  if len(result) != num_mapped_dims:
    raise ValueError(f"Named axes should be unique within each {arg_name} argument "
                     f"specification, but one of them is: {entry}")
  for axis in result.values():
    if axis < 0:
      raise ValueError(f"xmap doesn't support negative axes in {arg_name}")
  return result
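
# Example (illustrative): both entry formats normalize to a name -> position
# mapping; a list without a trailing ellipsis also records the expected rank:
#
#   _parse_entry('in_axes', {0: 'b'})     # AxisNamePos({'b': 0}), any rank
#   _parse_entry('in_axes', ['b', ...])   # AxisNamePos({'b': 0}), any rank
#   _parse_entry('in_axes', ['b', None])  # AxisNamePosWithRank({'b': 0}), expected_rank=2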

def _is_axes_leaf(entry):
  if isinstance(entry, dict) and all_leaves(entry.values()):
    return True
  # NOTE: `None`s are not considered leaves by `all_leaves`
  if isinstance(entry, (tuple, list)) and all_leaves(v for v in entry if v is not None):
    return True
  return False

def _prepare_axes(axes, arg_name):
  entries, treedef = tree_flatten(axes, is_leaf=_is_axes_leaf)
  entries = map(partial(_parse_entry, arg_name), entries)
  return tree_unflatten(treedef, entries), entries, treedef

Resource = Union[ResourceAxisName, SerialLoop]
ResourceSet = Union[Resource, Tuple[Resource, ...]]
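
# Example (illustrative): a ResourceSet value in axis_resources may be a single
# mesh axis name ('x'), a tuple of resources (('x', 'y')), or an anonymous
# SerialLoop(8); a tuple means the named axis is partitioned over all of them.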

# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
def xmap(fun: Callable,
         in_axes,
         out_axes,
         *,
         axis_sizes: Optional[Mapping[AxisName, int]] = None,
         axis_resources: Optional[Mapping[AxisName, ResourceSet]] = None,
         donate_argnums: Union[int, Sequence[int]] = (),
         backend: Optional[str] = None) -> stages.Wrapped:
  """Assign a positional signature to a program that uses named array axes.

  .. warning::
    This is an experimental feature and the details can change at
    any time. Use at your own risk!

  .. warning::
    This docstring is aspirational. Not all features of the named axis
    programming model have been implemented just yet.

  The usual programming model of JAX (or really NumPy) associates each array
  with two pieces of metadata describing its type: the element type (``dtype``)
  and the ``shape``. :py:func:`xmap` extends this model by adding support for
  *named axes*. In particular, each array used in a function wrapped by
  :py:func:`xmap` can additionally have a non-empty ``named_shape`` attribute,
  which can be used to query the set of named axes (introduced by
  :py:func:`xmap`) appearing in that value along with their shapes.
  Furthermore, in most places where positional axis indices are allowed (for
  example the ``axes`` arguments in :py:func:`sum`), bound axis names are also
  accepted. The :py:func:`einsum` language is extended inside :py:func:`xmap`
  to additionally allow contractions that involve named axes. Broadcasting of
  named axes happens *by name*, i.e. all axes with equal names are expected to
  have equal shapes in all arguments of a broadcasting operation, while the
  result has a (set) union of all named axes. The positional semantics of the
  program remain unchanged, and broadcasting still implicitly right-aligns
  positional axes for unification. For an extended description of the
  :py:func:`xmap` programming model, please refer to the :py:func:`xmap`
  tutorial notebook in the main JAX documentation.

  Note that since all top-level JAX expressions are interpreted in the NumPy
  programming model, :py:func:`xmap` can also be seen as an adapter that
  converts a function that uses named axes (including in arguments and returned
  values) into one that takes and returns values that only have positional
  axes.

  The default lowering strategy of :py:func:`xmap` converts all named axes into
  positional axes, working similarly to multiple applications of
  :py:func:`vmap`. However, this behavior can be further customized by the
  ``axis_resources`` argument. When specified, each axis introduced by
  :py:func:`xmap` can be assigned to one or more *resource axes*. Those include
  the axes of the hardware mesh, as defined by the :py:class:`Mesh` context
  manager. Each value that has a named axis in its ``named_shape`` will be
  partitioned over all mesh axes that axis is assigned to. Hence,
  :py:func:`xmap` can be seen as an alternative to :py:func:`pmap` that also
  exposes a way to automatically partition the computation over multiple
  devices.

  .. warning::
    While it is possible to assign multiple axis names to a single resource axis,
    care has to be taken to ensure that none of those named axes co-occur in a
    ``named_shape`` of any value in the named program. At the moment this is
    **completely unchecked** and will result in **undefined behavior**. The
    final release of :py:func:`xmap` will enforce this invariant, but it is a
    work in progress.

    Note that you do not have to worry about any of this for as long as no
    resource axis is repeated in ``axis_resources.values()``.

  Note that any assignment of ``axis_resources`` doesn't ever change the
  results of the computation, but only how it is carried out (e.g. how many
  devices are used). This makes it easy to try out various ways of
  partitioning a single program in many distributed scenarios (both small- and
  large-scale), to maximize the performance. As such, :py:func:`xmap` can be
  seen as a way to seamlessly interpolate between :py:func:`vmap` and
  :py:func:`pmap`-style execution.

  Args:
    fun: Function that uses named axes. Its arguments and return
      value should be arrays, scalars, or (nested) standard Python containers
      (tuple/list/dict) thereof (in general: valid pytrees).
    in_axes: A Python object with the same container (pytree) structure as the
      signature of arguments to ``fun``, but with a positional-to-named axis
      mapping in place of every array argument. The valid positional-to-named
      mappings are: (1) a ``Dict[int, AxisName]`` specifying that the positional
      dimensions given by dictionary keys are to be converted to named axes
      of given names; (2) a list of axis names that ends with the Ellipsis object
      (``...``), in which case a number of leading positional axes of the argument
      will be converted into named axes inside the function. Note that ``in_axes``
      can also be a prefix of the argument container structure, in which case the
      mapping is repeated for all arrays in the collapsed subtree.
    out_axes: A Python object with the same container (pytree) structure as the
      returns of ``fun``, but with a positional-to-named axis mapping in place
      of every returned array. The valid positional-to-named mappings are the same
      as in ``in_axes``. Note that ``out_axes`` can also be a prefix of the return
      container structure, in which case the mapping is repeated for all arrays
      in the collapsed subtree.
    axis_sizes: A dict mapping axis names to their sizes. All axes defined by xmap
      have to appear either in ``in_axes`` or ``axis_sizes``. Sizes of axes
      that appear in ``in_axes`` are inferred from arguments whenever possible.
      In multi-host scenarios, the user-specified sizes are expected to be the
      global axis sizes (and might not match the expected size of local inputs).
    axis_resources: A dictionary mapping the axes introduced in this
      :py:func:`xmap` to one or more resource axes. Any array that has in its
      shape an axis with some resources assigned will be partitioned over the
      resources associated with the respective resource axes.
    donate_argnums: Specify which argument buffers are "donated" to the computation.
      It is safe to donate argument buffers if you no longer need them once the
      computation has finished. In some cases XLA can make use of donated
      buffers to reduce the amount of memory needed to perform a computation,
      for example recycling one of your input buffers to store a result. You
      should not reuse buffers that you donate to a computation; JAX will raise
      an error if you try to.

      For more details on buffer donation see the `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.

    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  Returns:
    A version of ``fun`` that takes in arrays with positional axes in place of
    named axes bound in this :py:func:`xmap` call, and results with all named
    axes converted to positional axes. If ``axis_resources`` is specified,
    ``fun`` can additionally execute in parallel on multiple devices.

  For example, :py:func:`xmap` makes it very easy to convert a function that
  computes the vector inner product (such as :py:func:`jax.numpy.vdot`) into
  one that computes a matrix multiplication:

  >>> import jax.numpy as jnp
  >>> x = jnp.arange(10).reshape((2, 5))
  >>> xmap(jnp.vdot,
  ...      in_axes=({0: 'left'}, {1: 'right'}),
  ...      out_axes=['left', 'right', ...])(x, x.T)
  Array([[ 30,  80],
         [ 80, 255]], dtype=int32)

  Note that the contraction in the program is performed over the positional axes,
  while named axes are just a convenient way to achieve batching. While this
  might seem like a silly example at first, it might turn out to be useful in
  practice, since in conjunction with ``axis_resources`` this makes it possible
  to implement a distributed matrix-multiplication in just a few lines of code::

    devices = np.array(jax.devices())[:4].reshape((2, 2))
    with Mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
      distributed_out = xmap(
        jnp.vdot,
        in_axes=({0: 'left'}, {1: 'right'}),
        out_axes=['left', 'right', ...],
        axis_resources={'left': 'x', 'right': 'y'})(x, x.T)

  Still, the above examples are quite simple. After all, the xmapped
  computation was a simple NumPy function that didn't use the axis names at all!
  So, let's explore a slightly larger example which is linear regression::

    def regression_loss(x, y, w, b):
      # Contract over in_features. Batch and out_features are present in
      # both inputs and output, so they don't need to be mentioned
      y_pred = jnp.einsum('{in_features},{in_features}->{}', x, w) + b
      error = jnp.sum((y - y_pred) ** 2, axis='out_features')
      return jnp.mean(error, axis='batch')

    xmap(regression_loss,
         in_axes=(['batch', 'in_features', ...],
                  ['batch', 'out_features', ...],
                  ['in_features', 'out_features', ...],
                  ['out_features', ...]),
         out_axes={})  # Loss is reduced over all axes, including batch!

  .. note::
    When using ``axis_resources`` along with a mesh that is controlled by
    multiple JAX hosts, keep in mind that in any given process :py:func:`xmap`
    only expects the data slice that corresponds to its local devices to be
    specified. This is in line with the current multi-host :py:func:`pmap`
    programming model.
  """
  check_callable(fun)

  if isinstance(in_axes, list) and not _is_axes_leaf(in_axes):
    # To be a tree prefix of the positional args tuple, in_axes can never be a
    # list: if in_axes is not a leaf, it must be a tuple of trees. However,
    # in cases like these users expect tuples and lists to be treated
    # essentially interchangeably, so we canonicalize lists to tuples here
    # rather than raising an error. https://github.com/google/jax/issues/2367
    in_axes = tuple(in_axes)

  if in_axes == ():  # Allow empty argument lists
    in_axes, in_axes_entries = (), []
  else:
    in_axes, in_axes_entries, _ = _prepare_axes(in_axes, "in_axes")
  if out_axes == ():
    raise ValueError("xmapped functions must return at least one value")
  else:
    out_axes, out_axes_entries, out_axes_treedef = _prepare_axes(out_axes, "out_axes")
    out_axes_entries = tuple(out_axes_entries)  # Make entries hashable

  axis_sizes = {} if axis_sizes is None else axis_sizes
  axis_resources = {} if axis_resources is None else axis_resources

  axis_sizes_names = set(axis_sizes.keys())
  in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
  defined_names = axis_sizes_names | in_axes_names
  out_axes_names = set(it.chain(*(spec.keys() for spec in out_axes_entries)))

  anon_serial_loops = []
  def normalize_resource(r) -> ResourceAxisName:
    if isinstance(r, SerialLoop):
      name = fresh_resource_name()
      anon_serial_loops.append((name, r.length))
      return name
    return r

  axes_with_resources = set(axis_resources.keys())
  if axes_with_resources - defined_names:
    raise ValueError(f"All axes that were assigned resources have to appear in "
                     f"in_axes or axis_sizes, but the following are missing: "
                     f"{axes_with_resources - defined_names}")
  if out_axes_names - defined_names:
    raise ValueError(f"All axis names appearing in out_axes must also appear in "
                     f"in_axes or axis_sizes, but the following are missing: "
                     f"{out_axes_names - defined_names}")

  normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = {}
  for axis in defined_names:
    resources = axis_resources.get(axis, ())
    if not isinstance(resources, tuple):
      resources = (resources,)
    normalized_axis_resources[axis] = tuple(unsafe_map(normalize_resource, resources))
  frozen_axis_resources = FrozenDict(normalized_axis_resources)
  necessary_resources = set(it.chain(*frozen_axis_resources.values()))

  for axis, resources in frozen_axis_resources.items():
    if len(set(resources)) != len(resources):  # type: ignore
      raise ValueError(f"Resource assignment of a single axis must be a tuple of "
                       f"distinct resources, but specified {resources} for axis {axis}")

  donate_argnums = _ensure_index_tuple(donate_argnums)

  # A little performance optimization to avoid iterating over all args unnecessarily
  has_input_rank_assertions = any(spec.expected_rank is not None for spec in in_axes_entries)
  has_output_rank_assertions = any(spec.expected_rank is not None for spec in out_axes_entries)

  def infer_params(*args):
    # Putting this outside of fun_mapped would make resources lexically scoped
    resource_env = thread_resources.env
    available_resources = set(resource_env.shape.keys())

    if necessary_resources - available_resources:
      raise ValueError(f"In-scope resources are insufficient to execute the "
                       f"xmapped function. The missing resources are: "
                       f"{necessary_resources - available_resources}")

    args_flat, in_tree = tree_flatten(args)
    fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    if donate_argnums:
      donated_invars = donation_vector(donate_argnums, args, ())
    else:
      donated_invars = (False,) * len(args_flat)
    in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True)

    # Some pytree containers might be unhashable, so we flatten the out_axes
    # pytree into a treedef and entries which are guaranteed to be hashable.
    out_axes_thunk = HashableFunction(
        lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
        closure=(out_axes_entries, out_axes_treedef))

    axis_resource_count = _get_axis_resource_count(
        frozen_axis_resources, resource_env)

    for axis, size in axis_sizes.items():
      resources = axis_resource_count[axis]
      if size % resources.nglobal != 0:
        global_size = "Global size" if resources.distributed else "Size"
        raise ValueError(f"{global_size} of axis {axis} ({size}) is not divisible "
                         f"by the total number of resources assigned to this axis "
                         f"({frozen_axis_resources[axis]}, {resources.nglobal} in total)")
    frozen_global_axis_sizes = _get_axis_sizes(
        args_flat, in_axes_flat, axis_sizes, axis_resource_count)

    missing_sizes = defined_names - set(frozen_global_axis_sizes.keys())
    if missing_sizes:
      raise ValueError(f"Failed to infer size of axes: {', '.join(unsafe_map(str, missing_sizes))}. "
                       f"You've probably passed in empty containers in place of arguments that had "
                       f"those axes in their in_axes. Provide the sizes of missing axes explicitly "
                       f"via axis_sizes to fix this error.")

    if has_input_rank_assertions:
      for arg, spec in zip(args_flat, in_axes_flat):
        if spec.expected_rank is not None and spec.expected_rank != arg.ndim:
          raise ValueError(f"xmap argument has an in_axes specification of {spec.user_repr}, "
                           f"which asserts that it should be of rank {spec.expected_rank}, "
                           f"but the argument has rank {arg.ndim} (and shape {arg.shape})")

    _check_gda_or_array_xmap_partitioning(
        frozen_axis_resources, resource_env, frozen_global_axis_sizes,
        in_axes_flat, args_flat)

    params = dict(
        name=getattr(fun, '__name__', '<unnamed function>'),
        in_axes=tuple(in_axes_flat),
        out_axes_thunk=out_axes_thunk,
        donated_invars=donated_invars,
        global_axis_sizes=frozen_global_axis_sizes,
        axis_resources=frozen_axis_resources,
        resource_env=resource_env,
        backend=backend,
        spmd_in_axes=None,
        spmd_out_axes_thunk=None)
    return fun_flat, args_flat, params, in_tree, out_tree

  def verify_outputs(out_flat, out_tree, params):
    if has_output_rank_assertions:
      for out, spec in zip(out_flat, params['out_axes_thunk']()):
        if spec.expected_rank is not None and spec.expected_rank != out.ndim:
          raise ValueError(f"xmap output has an out_axes specification of {spec.user_repr}, "
                           f"which asserts that it should be of rank {spec.expected_rank}, "
                           f"but the output has rank {out.ndim} (and shape {out.shape})")
    return tree_unflatten(out_tree(), out_flat)

  def decorate_serial(f):
    for loop_params in reversed(anon_serial_loops):
      f = serial_loop(*loop_params)(f)
    return f

  @wraps(fun)
  @decorate_serial
  def fun_mapped(*args):
    tree_map(dispatch.check_arg, args)
    fun_flat, args_flat, params, _, out_tree = infer_params(*args)
    out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
    return verify_outputs(out_flat, out_tree, params)

  @decorate_serial
  def lower(*args, _experimental_lowering_platform: Optional[str] = None):
    fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
    avals_flat = [shaped_abstractify(arg) for arg in args_flat]
    computation = make_xmap_callable(
        fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
        params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
        params['resource_env'], params['backend'], params['spmd_in_axes'],
        params['spmd_out_axes_thunk'],
        _experimental_lowering_platform, *avals_flat)

    in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
    in_avals = in_tree.unflatten(avals_flat)
    return stages.Lowered.from_flat_info(
        computation, in_tree, in_avals, donate_argnums, out_tree(),  # type: ignore
        no_kwargs=True)

  fun_mapped.lower = lower
  return fun_mapped
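
# Example (illustrative): like jit, the wrapped function exposes ahead-of-time
# lowering through the `lower` attribute attached above:
#
#   f = xmap(jnp.vdot, in_axes=({0: 'l'}, {1: 'r'}), out_axes=['l', 'r', ...])
#   lowered = f.lower(x, y)       # a stages.Lowered
#   compiled = lowered.compile()  # compiled executable; compiled(x, y) runs it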

def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars,
              global_axis_sizes, axis_resources, resource_env, backend,
              spmd_in_axes, spmd_out_axes_thunk):
  in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
  xmap_callable = make_xmap_callable(
      fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
      axis_resources, resource_env, backend,
      spmd_in_axes, spmd_out_axes_thunk,
      None, *in_avals).compile().unsafe_call
  distributed_debug_log(("Running xmapped function", name),
                        ("python function", fun.f),
                        ("mesh", resource_env.physical_mesh),
                        ("abstract args", in_avals))
  return xmap_callable(*args)

@lu.cache
def make_xmap_callable(fun: lu.WrappedFun,
                       name,
                       in_axes, out_axes_thunk, donated_invars,
                       global_axis_sizes, axis_resources, resource_env, backend,
                       spmd_in_axes, spmd_out_axes_thunk,
                       lowering_platform: Optional[str],
                       *in_avals):
  plan = EvaluationPlan.from_axis_resources(
      axis_resources, resource_env, global_axis_sizes)

  # TODO: Making axis substitution final style would allow us to avoid
  # tracing to jaxpr here
  mapped_in_avals = [_delete_aval_axes(aval, in_axes, global_axis_sizes)
                     for aval, in_axes in zip(in_avals, in_axes)]
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                                   "for xmap in {elapsed_time} sec",
                                   event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
    out_axes = out_axes_thunk()
    _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
    # NOTE: We don't use avals and all params, so only pass in the relevant parts (too lazy...)
    _resource_typing_xmap([], dict(axis_resources=axis_resources,
                                   out_axes=out_axes,
                                   call_jaxpr=jaxpr,
                                   resource_env=resource_env,
                                   name=name),
                          source_info_util.new_source_info(), resource_env, {})
  jaxpr = plan.subst_axes_with_resources(jaxpr)
  use_spmd_lowering = config.experimental_xmap_spmd_lowering
  ensure_fixed_sharding = config.experimental_xmap_ensure_fixed_sharding
  if use_spmd_lowering and ensure_fixed_sharding:
    jaxpr = _fix_inferred_spmd_sharding(jaxpr, resource_env)

  f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts)))
  f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
  f = plan.vectorize_and_loop(f, in_axes, out_axes)

  used_resources = _jaxpr_resources(jaxpr, resource_env) | set(it.chain(*axis_resources.values()))
  used_mesh_axes = used_resources & resource_env.physical_resource_axes
  if used_mesh_axes:
    assert spmd_in_axes is None and spmd_out_axes_thunk is None  # No outer xmaps, so should be None
    mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
    mesh = resource_env.physical_mesh
    tiling_method: pxla.TilingMethod
    if config.experimental_xmap_spmd_lowering_manual:
      manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
      tiling_method = pxla.TileManual(manual_mesh_axes)
    else:
      tiling_method = pxla.TileVectorize()
    in_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(i))
                    for i in mesh_in_axes]
    out_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(o))
                     for o in mesh_out_axes]
    return pxla.lower_mesh_computation(
        f, 'xmap', name, mesh,
        in_shardings, out_shardings, donated_invars,
        use_spmd_lowering, in_avals,
        tiling_method=tiling_method,
        lowering_platform=lowering_platform)
  else:
    return dispatch.sharded_lowering(
        f, name, donated_invars, True, in_avals, (None,) * len(in_avals),
        lowering_platform=lowering_platform)


class EvaluationPlan(NamedTuple):
  """Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
  resource_env: ResourceEnv
  physical_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
  loop_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
  axis_subst_dict: Dict[AxisName, Tuple[ResourceAxisName, ...]]
  axis_vmap_size: Dict[AxisName, Optional[int]]

  @property
  def axis_subst(self) -> core.AxisSubst:
    return lambda name: self.axis_subst_dict.get(name, (name,))

  @property
  def resource_axis_env(self):
    env = dict(self.resource_env.shape)
    for axis, size in self.axis_vmap_size.items():
      if size is None:
        continue
      vmap_axis = self.axis_subst_dict[axis][-1]
      env[vmap_axis] = size
    return env

  @classmethod
  def from_axis_resources(cls,
                          axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
                          resource_env: ResourceEnv,
                          global_axis_sizes: Dict[AxisName, int]):
    physical_axis_resources, loop_axis_resources = _unzip_axis_resources(
        axis_resources, resource_env)
    axis_resource_count = _get_axis_resource_count(
        axis_resources, resource_env)
    axis_subst_dict = dict(axis_resources)
    axis_vmap_size: Dict[AxisName, Optional[int]] = {}
    for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])):
      num_resources = axis_resource_count[naxis]
      assert global_axis_sizes[naxis] % num_resources.nglobal == 0
      local_tile_size = global_axis_sizes[naxis] // num_resources.nglobal
      # We have to vmap when there are no resources (to handle the axis name!) or
      # when every resource gets chunks of values.
      if not raxes or local_tile_size > 1:
        axis_vmap_size[naxis] = local_tile_size
        axis_subst_dict[naxis] += (fresh_resource_name(naxis),)
      else:
        axis_vmap_size[naxis] = None
    return cls(resource_env,
               physical_axis_resources, loop_axis_resources,
               axis_subst_dict, axis_vmap_size)

  def subst_axes_with_resources(self, jaxpr):
    try:
      if any(self.loop_axis_resources.values()):
        _check_no_loop_collectives(jaxpr, self.loop_axis_resources)
      with core.extend_axis_env_nd(self.resource_axis_env.items()):
        return core.subst_axis_names_jaxpr(jaxpr, self.axis_subst)
    except core.DuplicateAxisNameError:
      raise AssertionError("Incomplete resource type-checking? Please open a bug report!")

  def vectorize_and_loop(self, f: lu.WrappedFun, in_axes, out_axes) -> lu.WrappedFun:
    vmap_axes = {
        naxis: raxes[-1]
        for naxis, raxes in self.axis_subst_dict.items()
        if self.axis_vmap_size[naxis] is not None
    }
    for naxis, vaxis in sorted(vmap_axes.items(), key=lambda x: x[1].uid):
      local_tile_size = self.axis_vmap_size[naxis]
      map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
      map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
      f = batching.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)

    used_loops = set(it.chain.from_iterable(self.loop_axis_resources.values()))
    if not used_loops:
      return f

    if len(used_loops) > 1:
      # TODO: Support multiple loops
      raise NotImplementedError("Only one loop per xmap is supported")
    loop_in_axes = _to_resource_axes(in_axes, self.loop_axis_resources)
    loop_out_axes = _to_resource_axes(out_axes, self.loop_axis_resources)
    loop_name, = used_loops
    loop_length = self.resource_env.shape[loop_name]
    def looped_f(*args):
      def body(i, _):
        # XXX: This call_wrapped is only valid under the assumption that scan
        # only ever traces the body once (which it does at the moment).
        result = f.call_wrapped(
            *(_slice_tile(arg, spec.get(loop_name, None), i, loop_length)
              for arg, spec in zip(args, loop_in_axes)))
        return i + 1, result
      _, stacked_results = lax.scan(body, 0, (), length=loop_length)
      return [_merge_leading_axis(sresult, spec.get(loop_name, None))
              for sresult, spec in zip(stacked_results, loop_out_axes)]
    return lu.wrap_init(looped_f)

  def to_mesh_axes(self, in_axes, out_axes=None):
    """
    Convert in/out_axes parameters ranging over logical dimensions to
    in/out_axes that range over the mesh dimensions.
    """
    if out_axes is None:
      return _to_resource_axes(in_axes, self.physical_axis_resources)
    else:
      return (_to_resource_axes(in_axes, self.physical_axis_resources),
              _to_resource_axes(out_axes, self.physical_axis_resources))
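
# Example (illustrative): with axis_resources={'i': ('x',)}, assuming a mesh
# where 'x' has 2 devices and the named axis 'i' has global size 6,
# from_axis_resources records axis_subst_dict['i'] == ('x', <fresh vmap axis>)
# and axis_vmap_size['i'] == 3: each 'x' shard keeps a residual tile of 3
# elements that vectorize_and_loop handles with a final vmap.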

# -------- xmap primitive and its transforms --------

# xmap has a different set of parameters than pmap, so we make it its own primitive type
class XMapPrimitive(core.MapPrimitive):
  def __init__(self):
    super().__init__('xmap')
    self.def_impl(xmap_impl)
    self.def_custom_bind(self.bind)

  def bind(self, fun, *args, in_axes, **params):
    assert len(in_axes) == len(args), (in_axes, args)
    return core.map_bind(self, fun, *args, in_axes=in_axes, **params)

  def process(self, trace, fun, tracers, params):
    return trace.process_xmap(self, fun, tracers, params)

  def post_process(self, trace, out_tracers, params):
    post_process = getattr(trace, 'post_process_xmap', None)
    if post_process is None:
      raise NotImplementedError
    return post_process(self, out_tracers, params)

  def get_bind_params(self, params):
    new_params = dict(params)
    jaxpr = new_params.pop('call_jaxpr')
    subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
    axes = new_params.pop('out_axes')
    new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
    spmd_axes = new_params.pop('spmd_out_axes')
    if spmd_axes is not None:
      new_params['spmd_out_axes_thunk'] = \
          HashableFunction(lambda: spmd_axes, closure=spmd_axes)
    else:
      new_params['spmd_out_axes_thunk'] = None
    return [subfun], new_params

xmap_p = XMapPrimitive()
core.EvalTrace.process_xmap = core.EvalTrace.process_call  # type: ignore
def _process_xmap_default(self, call_primitive, f, tracers, params):
  raise NotImplementedError(f"{type(self)} must override process_xmap to handle xmap")
core.Trace.process_xmap = _process_xmap_default  # type: ignore

def _xmap_axis_subst(params, subst, traverse):
  if 'call_jaxpr' not in params:  # TODO(apaszke): This feels sketchy, but I'm not sure why
    return params
  if not traverse:
    return params
  def shadowed_subst(name):
    return (name,) if name in params['global_axis_sizes'] else subst(name)
  with core.extend_axis_env_nd(params['global_axis_sizes'].items()):
    new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'], shadowed_subst)
  return dict(params, call_jaxpr=new_jaxpr)
core.axis_substitution_rules[xmap_p] = _xmap_axis_subst

# NOTE: We don't have to handle spmd_{in|out}_axes here, because
# SPMD batching always gets involved as the last transform before XLA translation
ad.JVPTrace.process_xmap = ad.JVPTrace.process_call  # type: ignore
ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params

def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, cts_in))  # empty consts
  fun = lu.hashable_partial(
      lu.wrap_init(ad.backward_pass),
      call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()), False)
  fun, nz_arg_cts = ad.nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*(axis for axis, x in zip(in_axes, args) if not ad.is_undefined_primal(x)),
                 *(axis for axis, x in zip(out_axes, cts_in) if type(x) is not ad.Zero))
  # NOTE: This assumes that the output cotangents being zero is a deterministic
  # function of which input cotangents were zero.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is ad.Zero for c in cts_in)))
  def out_axes_thunk():
    return tuple(axis for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params,
                    name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes,
                    out_axes_thunk=out_axes_thunk,
                    donated_invars=(False,) * len(new_in_axes),
                    spmd_out_axes_thunk=None)
  del new_params['out_axes']
  del new_params['spmd_out_axes']
  out_flat = xmap_p.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  axis_resource_count = _get_axis_resource_count(
      params['axis_resources'], params['resource_env'])
  local_axis_sizes = {
      axis: axis_resource_count[axis].to_local(global_size)
      for axis, global_size in params['global_axis_sizes'].items()
  }
  def unmap_zero(zero, axes):
    return ad.Zero(_insert_aval_axes(zero.aval, axes, local_axis_sizes))
  return tuple(unmap_zero(arg_ct, in_axis) if type(arg_ct) is ad.Zero else arg_ct
               for arg_ct, in_axis in zip(arg_cts, in_axes))
ad.primitive_transposes[xmap_p] = _xmap_transpose


def _typecheck_xmap(
    _, *in_atoms, call_jaxpr, name, in_axes, out_axes, donated_invars,
    global_axis_sizes, axis_resources, resource_env, backend,
    spmd_in_axes, spmd_out_axes):
  in_avals = [x.aval for x in in_atoms]
  axis_resource_count = _get_axis_resource_count(
      axis_resources, resource_env)
  local_axis_sizes = {
      axis: axis_resource_count[axis].to_local(global_size)
      for axis, global_size in global_axis_sizes.items()
  }
  binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes)
                     for v, a_in_axes in zip(call_jaxpr.invars, in_axes)]
  for binder_in_aval, in_aval in zip(binder_in_avals, in_avals):
    if not core.typecompat(binder_in_aval, in_aval):
      raise core.JaxprTypeError(
          f"xmap passes operand {in_aval} to jaxpr expecting {binder_in_aval}")

  with core.extend_axis_env_nd(global_axis_sizes.items()):
    core._check_jaxpr(lambda: core.JaxprPpContext(), call_jaxpr)

  mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
  out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
               for a, a_out_axes in zip(mapped_out_avals, out_axes)]
  return out_avals, call_jaxpr.effects
core.custom_typechecks[xmap_p] = _typecheck_xmap


def _resource_typing_xmap(avals,
                          params,
                          source_info: source_info_util.SourceInfo,
                          resource_env,
                          outer_axis_resources):
  axis_resources = params['axis_resources']
  inner_axis_resources = dict(outer_axis_resources)
  inner_axis_resources.update(axis_resources)
  if len(inner_axis_resources) < len(outer_axis_resources) + len(axis_resources):
    overlap = set(outer_axis_resources) & set(axis_resources)
    raise JAXTypeError(
        f"Detected disallowed xmap axis name shadowing at "
        f"{source_info_util.summarize(source_info)} "
        f"(shadowed axes: {mesh_lib.show_axes(overlap)})")

  if resource_env.physical_mesh != params['resource_env'].physical_mesh:
    raise RuntimeError("Changing the physical mesh is not allowed inside xmap.")

  call_jaxpr = params['call_jaxpr']
  pxla.resource_typecheck(
      call_jaxpr, resource_env, inner_axis_resources,
      lambda: (f"an xmapped function {params['name']} " +
               (f"(xmap called at {source_info_util.summarize(source_info)})"
                if source_info else "")))

  for v, axes in zip(call_jaxpr.outvars, params['out_axes']):
    broadcast_axes = set(axes) - set(v.aval.named_shape)
    used_resources = set(it.chain.from_iterable(
        inner_axis_resources[a] for a in v.aval.named_shape))
    for baxis in broadcast_axes:
      baxis_resources = set(inner_axis_resources[baxis])
      overlap = baxis_resources & used_resources
      if overlap:
        resource_to_axis = {}
        for axis in v.aval.named_shape:
          for raxis in inner_axis_resources[axis]:
            resource_to_axis[raxis] = axis
        partitioning_axes = {resource_to_axis[raxis] for raxis in overlap}
        raise JAXTypeError(
            f"One of the outputs of the xmapped function ({params['name']}) is "
            f"broadcast along axis `{baxis}`, which is assigned to resources "
            f"{mesh_lib.show_axes(baxis_resources)}, but the output is already "
            f"partitioned along {mesh_lib.show_axes(overlap)}, because its "
            f"named shape contains {mesh_lib.show_axes(partitioning_axes)}")
pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap


# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  assert primitive is xmap_p
  in_avals = [t.aval for t in tracers]
  global_axis_sizes = params['global_axis_sizes']
  mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
                     for a, a_in_axes in zip(in_avals, params['in_axes'])]
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    with core.new_sublevel():
      jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
          f, self.main, mapped_in_avals)
  out_axes = params['out_axes_thunk']()
  if params['spmd_out_axes_thunk'] is not None:
    spmd_out_axes = params['spmd_out_axes_thunk']()
  else:
    spmd_out_axes = None
  axis_resource_count = _get_axis_resource_count(
      params['axis_resources'], params['resource_env'])
  local_axis_sizes = {
      axis: axis_resource_count[axis].to_local(global_size)
      for axis, global_size in global_axis_sizes.items()
  }
  out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
               for a, a_out_axes in zip(mapped_out_avals, out_axes)]
  _check_out_avals_vs_out_axes(out_avals, out_axes, params['global_axis_sizes'])
  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)
  new_in_axes = (AxisNamePos(user_repr='{}'),) * len(consts) + params['in_axes']
  if params['spmd_in_axes'] is None:
    new_spmd_in_axes = None
  else:
    new_spmd_in_axes = (None,) * len(consts) + params['spmd_in_axes']
  new_donated_invars = (False,) * len(consts) + params['donated_invars']
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    call_jaxpr = convert_constvars_jaxpr(jaxpr)
  new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
                    donated_invars=new_donated_invars,
                    spmd_in_axes=new_spmd_in_axes,
                    spmd_out_axes=spmd_out_axes,
                    call_jaxpr=call_jaxpr)
  del new_params['out_axes_thunk']
  del new_params['spmd_out_axes_thunk']
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                      new_params, call_jaxpr.effects, source_info)
  self.frame.add_eqn(eqn)
  return out_tracers
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap  # type: ignore

def _xmap_partial_eval_custom_params_updater(
    unks_in: Sequence[bool], inst_in: Sequence[bool],
    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
    num_res: int, params_known: dict, params_staged: dict
  ) -> Tuple[dict, dict]:
  assert params_known['spmd_in_axes'] is None is params_known['spmd_out_axes']
  assert params_staged['spmd_in_axes'] is None is params_staged['spmd_out_axes']

  # Prune inputs to jaxpr_known according to unks_in.
  donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
  in_axes_known, _ = pe.partition_list(unks_in, params_known['in_axes'])
  if num_res == 0:
    residual_axes = []
  else:
    # Residuals are the trailing num_res outputs of the known jaxpr.
    residual_axes = [
      AxisNamePos(zip(sort_named_shape, range(len(sort_named_shape))),
                  user_repr=f'<internal: {sort_named_shape}>')
      for named_shape in (v.aval.named_shape for v in params_known['call_jaxpr'].outvars[-num_res:])
      # We sort here to make the iteration order deterministic
      for sort_named_shape in [sorted(named_shape, key=str)]
    ]
  _, out_axes_known = pe.partition_list(kept_outs_known, params_known['out_axes'])
  new_params_known = dict(params_known,
                          in_axes=tuple(in_axes_known),
                          out_axes=(*out_axes_known, *residual_axes),
                          donated_invars=tuple(donated_invars_known))
  assert len(new_params_known['in_axes']) == len(params_known['call_jaxpr'].invars)
  assert len(new_params_known['out_axes']) == len(params_known['call_jaxpr'].outvars)

  # Add num_res new inputs to jaxpr_staged, and prune according to inst_in.
  _, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
  donated_invars_staged = [False] * num_res + donated_invars_staged
  _, in_axes_staged = pe.partition_list(inst_in, params_staged['in_axes'])
  in_axes_staged = [*residual_axes, *in_axes_staged]
  _, out_axes_staged = pe.partition_list(kept_outs_staged, params_staged['out_axes'])
  new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
                           out_axes=tuple(out_axes_staged),
                           donated_invars=tuple(donated_invars_staged))
  assert len(new_params_staged['in_axes']) == len(params_staged['call_jaxpr'].invars)
  assert len(new_params_staged['out_axes']) == len(params_staged['call_jaxpr'].outvars)
  return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xmap_p] = \
    partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
            _xmap_partial_eval_custom_params_updater)


@lu.transformation_with_aux
def out_local_named_shapes(local_axes, *args, **kwargs):
  ans = yield args, kwargs
  ans_axes = [frozenset(a.aval.named_shape) & local_axes for a in ans]
  yield ans, ans_axes


def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params):
  assert primitive is xmap_p
  assert params['spmd_out_axes_thunk'] is params['spmd_in_axes'] is None
  in_axes = params['in_axes']
  donated_invars = params['donated_invars']
  global_axis_sizes = params['global_axis_sizes']
  out_axes_thunk = params['out_axes_thunk']

  # Adjust input tracers' pvals for mapped axes, and unpack.
  in_pvals = [t.pval if t.pval.is_known() else
              pe.PartialVal.unknown(
                  _delete_aval_axes(t.pval.get_aval(), axes, global_axis_sizes))
              for t, axes in zip(tracers, in_axes)]
  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)

  # Wrap f to perform partial evaluation, and plumb out aux data.
  f = pe.trace_to_subjaxpr_nounits(f, self.main, False)
  f, aux = pe.partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
  # Also grab the local named shapes of the output (known and res).
  f, out_named_shapes = out_local_named_shapes(f, frozenset(global_axis_sizes))

  # Adjust params for knowns (donated_invars, in_axes, out_axes_thunk).
  out_axes = None  # cache this to avoid calling out_axes_thunk() more than once

  @as_hashable_function(closure=out_axes_thunk)
  def new_out_axes_thunk():
    nonlocal out_axes
    out_axes = out_axes_thunk()
    out_knowns, _, _, _ = aux()
    _, out_axes_known = partition_list(out_knowns, out_axes)
    return (*out_axes_known, *res_axes())
  def res_axes():
    _, _, jaxpr_unknown, _ = aux()
    num_res = len(jaxpr_unknown.constvars)
    res_named_shapes = out_named_shapes()[-num_res:] if num_res else []
    sorted_named_shapes = [sorted(ns, key=str) for ns in res_named_shapes]
    return [AxisNamePos(zip(named_shape, range(len(named_shape))),
                        user_repr=f'<internal: {named_shape}>')
            for named_shape in sorted_named_shapes]
  known_params = dict(
      params, in_axes=tuple(a for a, k in zip(in_axes, in_knowns) if k),
      donated_invars=tuple(d for d, k in zip(donated_invars, in_knowns) if k),
      out_axes_thunk=new_out_axes_thunk)

  # Run the known part.
  out = primitive.bind(f, *in_consts, **known_params)
  out_knowns, out_avals, jaxpr_unknown, env = aux()
  known_outvals, res = split_list(out, [len(out) - len(jaxpr_unknown.constvars)])
  with core.extend_axis_env_nd(global_axis_sizes.items()):
    jaxpr_unknown = pe.convert_constvars_jaxpr(jaxpr_unknown)

  # Set up new params.
  if out_axes is None:
    out_axes = out_axes_thunk()  # in case new_out_axes_thunk wasn't called during bind
  out_axes_unknown = [a for a, k in zip(out_axes, out_knowns) if not k]
  unknown_params = dict(
      params, call_jaxpr=jaxpr_unknown, out_axes=tuple(out_axes_unknown),
      spmd_out_axes=None,
      donated_invars=(*(False for _ in res),
                      *(d for d, k in zip(donated_invars, in_knowns) if not k)),
      in_axes=(*res_axes(), *(None for _ in env),
               *(a for a, k in zip(in_axes, in_knowns) if not k)))
  del unknown_params['out_axes_thunk']
  del unknown_params['spmd_out_axes_thunk']
  # Create input tracers for unknown part.
  res_tracers = map(self.new_instantiated_const, res)
  env_tracers = map(self.full_raise, env)
  unknown_arg_tracers = [t for t in tracers if not t.pval.is_known()]
  # Create output tracers for unknown part, adjusting avals.
  axis_resource_count = _get_axis_resource_count(
      params['axis_resources'], params['resource_env'])
  local_axis_sizes = {
      ax: axis_resource_count[ax].to_local(global_size)
      for ax, global_size in global_axis_sizes.items()}
  out_pvals = [pe.PartialVal.unknown(_insert_aval_axes(a, ax, local_axis_sizes))
               for a, ax in zip(out_avals, out_axes_unknown)]
  unknown_tracers_out = [pe.JaxprTracer(self, pval, None) for pval in out_pvals]
  # Build eqn to be staged out and attach it to unknown output tracers.
  eqn = pe.new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers),
                          unknown_tracers_out, primitive, unknown_params,
                          jaxpr_unknown.effects, source_info_util.current())
  for t in unknown_tracers_out: t.recipe = eqn
  return merge_lists(out_knowns, unknown_tracers_out, known_outvals)
pe.JaxprTrace.process_xmap = _jaxpr_trace_process_xmap

def _batch_trace_update_spmd_axes(
    spmd_in_axes, spmd_out_axes_thunk,
    axis_name, dims, dims_out_thunk):
  """Extends spmd in and out axes with the position of the trace's batch dimension."""
  not_mapped = batching.not_mapped
  def insert_spmd_axis(axes, nd):
    too_short = nd - len(axes)
    if too_short > 0:
      axes += (None,) * too_short
    return tuple_insert(axes, nd, axis_name)

  if spmd_in_axes is None:
    spmd_in_axes = ((),) * len(dims)
  new_spmd_in_axes = tuple(
      spmd_axes if d is not_mapped else insert_spmd_axis(spmd_axes, d)
      for spmd_axes, d in zip(spmd_in_axes, dims))

  @as_hashable_function(closure=spmd_out_axes_thunk)
  def new_spmd_out_axes_thunk():
    dims_out = dims_out_thunk()
    if spmd_out_axes_thunk is None:
      spmd_out_axes = ((),) * len(dims_out)
    else:
      spmd_out_axes = spmd_out_axes_thunk()
    return tuple(
        spmd_out_axes if nd is not_mapped else insert_spmd_axis(spmd_out_axes, nd)
        for spmd_out_axes, nd in zip(spmd_out_axes, dims_out))

  return new_spmd_in_axes, new_spmd_out_axes_thunk

def _axis_after_insertion(axis, inserted_named_axes):
  for inserted_axis in sorted(inserted_named_axes.values()):
    if inserted_axis >= axis:
      break
    axis += 1
  return axis
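
# Example (illustrative): if named axes are reinserted at positions {0, 2}, a
# batch dimension that sat at positional axis 1 of the mapped value ends up at
# axis 2 of the unmapped one:
#
#   _axis_after_insertion(1, AxisNamePos({'a': 0, 'b': 2}, user_repr='...'))  # == 2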

def _fmap_dims(axes, f):
  return AxisNamePos(((name, f(axis)) for name, axis in axes.items()),
                     user_repr=axes.user_repr)

def _batch_trace_process_xmap(self, is_spmd, primitive, f: lu.WrappedFun, tracers, params):
  not_mapped = batching.not_mapped
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  assert primitive is xmap_p
  if not is_spmd and all(dim is not_mapped for dim in dims):
    return primitive.bind(f, *vals, **params)
  else:
    assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
    new_in_axes = tuple(
        _fmap_dims(in_axes, lambda a: a + (d is not not_mapped and d <= a))
        for d, in_axes in zip(dims, params['in_axes']))
    mapped_dims_in = tuple(
        d if d is not_mapped else d - sum(a < d for a in in_axis.values())
        for d, in_axis in zip(dims, params['in_axes']))
    f, mapped_dims_out = batching.batch_subtrace(f, self.main, mapped_dims_in)
    out_axes_thunk: Callable[[], Sequence[AxisNamePos]] = params['out_axes_thunk']
    dims_out_thunk = lambda: tuple(d if d is not_mapped else _axis_after_insertion(d, out_axes)
                                   for d, out_axes in zip(mapped_dims_out(), out_axes_thunk()))
    # NOTE: This assumes that the choice of the dimensions over which outputs
    # are batched is entirely dependent on the function and not e.g. on the
    # data or its shapes.
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      return tuple(
          out_axes if d is not_mapped else
          _fmap_dims(out_axes, lambda a, nd=_axis_after_insertion(d, out_axes): a + (nd <= a))
          for out_axes, d in zip(out_axes_thunk(), mapped_dims_out()))

    if not is_spmd:
      assert params['spmd_in_axes'] is None and params['spmd_out_axes_thunk'] is None
      new_spmd_in_axes = None
      new_spmd_out_axes_thunk = None
    else:
      new_spmd_in_axes, new_spmd_out_axes_thunk = _batch_trace_update_spmd_axes(
          params['spmd_in_axes'], params['spmd_out_axes_thunk'],
          self.axis_name, dims, dims_out_thunk)

    new_params = dict(params,
                      in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk,
                      spmd_in_axes=new_spmd_in_axes,
                      spmd_out_axes_thunk=new_spmd_out_axes_thunk)
    vals_out = primitive.bind(f, *vals, **new_params)
    dims_out = dims_out_thunk()
    return [batching.BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
batching.BatchTrace.process_xmap = partialmethod(_batch_trace_process_xmap, False)  # type: ignore
pxla.SPMDBatchTrace.process_xmap = partialmethod(_batch_trace_process_xmap, True)  # type: ignore
|
|
|
|
|
|
def _batch_trace_post_process_xmap(self, primitive, out_tracers, params):
  not_mapped = batching.not_mapped
  BT = batching.BatchTracer
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) for t in out_tracers)
  main = self.main
  def todo(vals):
    trace = main.with_cur_sublevel()
    return [BT(trace, v, d if d is not_mapped else _axis_after_insertion(d, oa), s)
            for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)]
  def out_axes_transform(out_axes):
    return tuple(oa if d is not_mapped else
                 _fmap_dims(oa, lambda a, nd=_axis_after_insertion(d, oa): a + (nd <= a))
                 for oa, d in zip(out_axes, dims))
  return vals, (todo, out_axes_transform)
batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap


# -------- nested xmap handling --------

def _xmap_lowering_rule(ctx, *args, **kwargs):
  if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
    if config.experimental_xmap_spmd_lowering_manual:
      return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
    else:
      return _xmap_lowering_rule_spmd(ctx, *args, **kwargs)
  # ShardingContext is accepted alongside ReplicaAxisContext because when
  # axis_resources and mesh are not used with xmap, `make_xmap_callable` goes
  # via the `dispatch.sharded_lowering` path, which sets the context to
  # ShardingContext. sharding_impls.ShardingContext is not used for SPMD.
  elif isinstance(ctx.module_context.axis_context,
                  (sharding_impls.ReplicaAxisContext, sharding_impls.ShardingContext)):
    return _xmap_lowering_rule_replica(ctx, *args, **kwargs)
  else:
    raise AssertionError("Unrecognized axis context type!")
mlir.register_lowering(xmap_p, _xmap_lowering_rule)

def _xmap_lowering_rule_replica(ctx, *in_nodes,
                                call_jaxpr, name,
                                in_axes, out_axes, donated_invars,
                                global_axis_sizes,
                                spmd_in_axes, spmd_out_axes,
                                axis_resources, resource_env, backend):
  xla.check_backend_matches(backend, ctx.module_context.platform)
  # The only way for this assertion to be violated is if xmap is using the
  # SPMD lowering, but then this rule shouldn't even trigger.
  assert spmd_in_axes is None and spmd_out_axes is None
  plan = EvaluationPlan.from_axis_resources(
      axis_resources, resource_env, global_axis_sizes)

  axis_resource_count = _get_axis_resource_count(
      axis_resources, resource_env)
  if any(resource_count.distributed for resource_count in axis_resource_count.values()):
    raise NotImplementedError

  mesh = resource_env.physical_mesh
  mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)

  local_avals = [pxla.tile_aval_nd(
                     mesh.shape, aval_mesh_in_axes,
                     _insert_aval_axes(v.aval, aval_in_axes, global_axis_sizes))
                 for v, aval_in_axes, aval_mesh_in_axes
                 in zip(call_jaxpr.invars, in_axes, mesh_in_axes)]
  # We have to substitute before tracing, because we want the vectorized
  # axes to be used in the jaxpr.
  resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
  f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
  f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
  f = plan.vectorize_and_loop(f, in_axes, out_axes)
  # NOTE: We don't extend the resource env with the mesh shape, because those
  #       resources are already in scope! It's the outermost xmap that
  #       introduces them!
  vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
  _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
  const_nodes = map(mlir.ir_constants, consts)

  local_mesh_shape = mesh.local_mesh.shape
  tiled_ins = (
    mlir.lower_fun(partial(_tile, in_axes=arg_in_axes,
                           axis_sizes=local_mesh_shape),
                   multiple_results=False)(
        ctx.replace(primitive=None,
                    avals_in=[aval], avals_out=None),
        in_node)[0]
    for v, aval, in_node, arg_in_axes
    in zip(call_jaxpr.invars, ctx.avals_in, in_nodes, mesh_in_axes))

  # We in-line here rather than generating a Call HLO as in the xla_call
  # translation rule just because the extra tuple stuff is a pain.
  sub_ctx = ctx.module_context.replace(
      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
  if any(effects.ordered_effects.contains(eff) for eff
         in vectorized_jaxpr.effects):
    raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
  tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(),
                                     const_nodes, *tiled_ins,
                                     dim_var_values=ctx.dim_var_values)

  outs = [
      mlir.lower_fun(
          partial(_untile, out_axes=ans_out_axes, axis_sizes=local_mesh_shape,
                  platform=ctx.module_context.platform),
          multiple_results=False)(
              ctx.replace(primitive=None,
                          avals_in=[vectorized_outvar.aval],
                          avals_out=None), tiled_out)[0]
      for v, vectorized_outvar, tiled_out, ans_out_axes
      in zip(call_jaxpr.outvars, vectorized_jaxpr.outvars, tiled_outs,
             mesh_out_axes)]
  return outs


def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
                             call_jaxpr, name, in_axes, out_axes,
                             donated_invars, global_axis_sizes, spmd_in_axes,
                             spmd_out_axes, axis_resources,
                             resource_env, backend):
  xla.check_backend_matches(backend, ctx.module_context.platform)
  plan = EvaluationPlan.from_axis_resources(
      axis_resources, resource_env, global_axis_sizes)

  resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
  f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
  f = hide_mapped_axes(f, in_axes, out_axes)
  f = plan.vectorize_and_loop(f, in_axes, out_axes)
  mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
  mesh = resource_env.physical_mesh
  f = pxla.vtile_by_mesh(f, mesh, mesh_in_axes, mesh_out_axes)

  # XXX: We modify mesh_in_axes and mesh_out_axes here
  def add_spmd_axes(
      flat_mesh_axes: Sequence[ArrayMapping],
      flat_extra_axes: Optional[Sequence[Sequence[Sequence[MeshAxisName]]]]):
    if flat_extra_axes is None:
      return
    for axes, extra in zip(flat_mesh_axes, flat_extra_axes):
      for dim, dim_extra_axis in enumerate(extra):
        if dim_extra_axis is None: continue
        assert dim_extra_axis not in axes
        assert not config.jax_enable_checks or all(v != dim for v in axes.values())
        axes[dim_extra_axis] = dim
  add_spmd_axes(mesh_in_axes, spmd_in_axes)
  add_spmd_axes(mesh_out_axes, spmd_out_axes)
  global_in_avals = ctx.avals_in
  vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)

  global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
  sharded_global_in_nodes = [
    [mlir.wrap_with_sharding_op(node, global_sharding_spec(aval, aval_axes).sharding_proto())]
    if aval_axes else [node]
    for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
  ]
  const_nodes = map(mlir.ir_constants, consts)

  # We in-line here rather than generating a Call HLO as in the xla_call
  # translation rule just because the extra tuple stuff is a pain.
  sub_ctx = ctx.module_context.replace(
      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
  if any(effects.ordered_effects.contains(eff) for eff
         in vectorized_jaxpr.effects):
    raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
  global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
      mlir.TokenSet(), const_nodes, *sharded_global_in_nodes,
      dim_var_values=ctx.dim_var_values)

  sharded_global_out_nodes = [
    mlir.wrap_with_sharding_op(node, global_sharding_spec(aval, aval_axes).sharding_proto())
    if aval_axes else node
    for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
  ]

  return sharded_global_out_nodes


def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
                                    call_jaxpr, name, in_axes, out_axes,
                                    donated_invars, global_axis_sizes, spmd_in_axes,
                                    spmd_out_axes, axis_resources,
                                    resource_env, backend):
  assert spmd_in_axes is None and spmd_out_axes is None
  # This first part (up to vtile_manual) is shared with the non-MANUAL SPMD rule.
  xla.check_backend_matches(backend, ctx.module_context.platform)
  plan = EvaluationPlan.from_axis_resources(
      axis_resources, resource_env, global_axis_sizes)
  manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))

  resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
  f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
  f = hide_mapped_axes(f, in_axes, out_axes)
  f = plan.vectorize_and_loop(f, in_axes, out_axes)

  # NOTE: Sharding constraints are handled entirely by vtile_manual!
  mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
  mesh = resource_env.physical_mesh
  f = pxla.vtile_manual(f, tuple(manual_mesh_axes), mesh, mesh_in_axes, mesh_out_axes)

  # NOTE: We don't extend the resource env with the mesh shape, because those
  #       resources are already in scope! It's the outermost xmap that
  #       introduces them!
  global_in_avals = ctx.avals_in
  vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
  const_nodes = map(mlir.ir_constants, consts)

  # We in-line here rather than generating a Call HLO as in the xla_call
  # translation rule just because the extra tuple stuff is a pain.
  assert isinstance(ctx.module_context.axis_context,
                    sharding_impls.SPMDAxisContext)
  sub_ctx = ctx.module_context.replace(
      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')),
      axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
  if any(effects.ordered_effects.contains(eff) for eff
         in vectorized_jaxpr.effects):
    raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
  global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
      mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes),
      dim_var_values=ctx.dim_var_values)

  return global_out_nodes


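# Computes, for each dimension of a tile, the offset of the current
# participant's tile within the full array. For example, with tile_shape
# (2, 4), axes {'i': 0} and axis_sizes {'i': 4}, the result is
# [lax.axis_index('i') * 2, 0].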
def _tile_base_indices(tile_shape, axes, axis_sizes):
  zero = np.zeros((), dtype=np.int32)
  linear_idxs = [zero] * len(tile_shape)
  strides = [1] * len(tile_shape)
  for name, axis in reversed(axes.items()):
    axis_index = lax.axis_index(name)
    stride_c = np.array(strides[axis], np.int32)
    if linear_idxs[axis] is zero and strides[axis] == 1:
      linear_idxs[axis] = axis_index
    else:
      linear_idxs[axis] = lax.add(linear_idxs[axis],
                                  lax.mul(axis_index, stride_c))
    strides[axis] *= axis_sizes[name]
  return [zero if linear_idx is zero else
          lax.mul(linear_idx, np.array(tile_dim_size, np.int32))
          for linear_idx, tile_dim_size in zip(linear_idxs, tile_shape)]


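# Slices out the current participant's tile of `x`. For example, for x of
# shape (8, 4) with in_axes {'i': 0} and axis_sizes {'i': 4}, each participant
# receives a (2, 4) slice starting at row 2 * lax.axis_index('i').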
def _tile(x, in_axes, axis_sizes):
  if not in_axes:
    return x
  tile_shape = list(x.shape)
  for name, axis in in_axes.items():
    axis_size = axis_sizes[name]
    assert tile_shape[axis] % axis_size == 0
    tile_shape[axis] //= axis_size
  base_idxs = _tile_base_indices(tile_shape, in_axes, axis_sizes)
  return lax.dynamic_slice(x, base_idxs, tile_shape)


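# Reassembles the full array from per-participant tiles: each participant
# writes its tile into a zero-padded buffer of the full shape at its base
# indices, and a psum over the mapped axes then sums the disjoint tiles into
# the complete result on every participant.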
# TODO(b/110096942): more efficient gather
def _untile(x, out_axes, axis_sizes, platform):
  # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
  convert_bool = (np.issubdtype(x.dtype, np.bool_)
                  and platform in ('cpu', 'gpu'))
  if convert_bool:
    x = lax.convert_element_type(x, np.dtype(np.float32))

  tile_shape = list(x.shape)
  shape = list(tile_shape)
  for name, axis in out_axes.items():
    shape[axis] *= axis_sizes[name]
  base_idxs = _tile_base_indices(tile_shape, out_axes, axis_sizes)

  padded = lax.broadcast(np.array(0, x.dtype), shape)
  padded = lax.dynamic_update_slice(padded, x, base_idxs)
  out = lax.psum(padded, tuple(out_axes.keys()))

  # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
  if convert_bool:
    nonzero = lax.ne(out, np.array(0, dtype=np.float32))
    out = lax.convert_element_type(nonzero, np.dtype(np.bool_))
  return out


# -------- helper functions --------

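# These two helpers are inverses of each other. For example, deleting axes
# {'i': 0} from an aval of shape (4, 5) with global size i=4 yields shape (5,)
# with 'i' moved into the named shape, while inserting {'i': 0} back restores
# shape (4, 5) and drops 'i' from the named shape.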
def _delete_aval_axes(aval, axes: AxisNamePos, global_axis_sizes):
  assert isinstance(aval, core.ShapedArray)
  shape = list(aval.shape)
  named_shape = dict(aval.named_shape)
  for name, dim in sorted(axes.items(), key=lambda x: x[1], reverse=True):
    named_shape[name] = global_axis_sizes[name]
    del shape[dim]
  return aval.update(shape=tuple(shape), named_shape=named_shape)

def _insert_aval_axes(aval, axes: AxisNamePos, local_axis_sizes):
  assert isinstance(aval, core.ShapedArray)
  shape = list(aval.shape)
  named_shape = dict(aval.named_shape)
  for name, dim in sorted(axes.items(), key=lambda x: x[1]):
    shape.insert(dim, local_axis_sizes[name])
    named_shape.pop(name, None)  # The name might be missing --- it's a broadcast.
  return aval.update(shape=tuple(shape), named_shape=named_shape)


class ResourceCount(NamedTuple):
  nglobal: int
  nlocal: Optional[int]
  distributed: bool

  def to_local(self, global_size):
    return global_size


def _get_axis_resource_count(
    axis_resources, resource_env) -> Dict[ResourceAxisName, ResourceCount]:
  global_res_shape = resource_env.shape
  local_res_shape = None

  distributed = (False if resource_env.physical_mesh.empty else
                 resource_env.physical_mesh.size != len(resource_env.physical_mesh.local_devices))
  resource_count_map = {}
  for axis, resources in axis_resources.items():
    if local_res_shape is None:
      nlocal = None
    else:
      nlocal = int(np.prod(map(local_res_shape.get, resources), dtype=np.int64))
    resource_count_map[axis] = ResourceCount(
        int(np.prod(map(global_res_shape.get, resources), dtype=np.int64)),
        nlocal, distributed)
  return resource_count_map


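# Infers the size of every named axis from the shapes of the arguments. For
# example, an argument of shape (4, 8) with in_axes {'i': 1} binds axis 'i'
# to size 8; a later argument implying a different size for 'i' is an error.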
def _get_axis_sizes(args_flat: Iterable[Any],
                    in_axes_flat: Iterable[AxisNamePos],
                    global_axis_sizes: Dict[AxisName, int],
                    axis_resource_count: Dict[AxisName, ResourceCount]):
  global_axis_sizes = dict(global_axis_sizes)
  for arg, in_axes in zip(args_flat, in_axes_flat):
    for name, dim in in_axes.items():
      try:
        dim_size = arg.shape[dim]
      except IndexError:
        # TODO(apaszke): Handle negative indices. Check for overlap too!
        raise ValueError(f"One of the xmap arguments has an in_axes specification of "
                         f"{in_axes.user_repr}, which implies that it has at least "
                         f"{max(in_axes.values()) + 1} dimensions, but the argument "
                         f"has rank {arg.ndim}")
      global_dim_size = dim_size
      if name in global_axis_sizes:
        expected_global_dim_size = global_axis_sizes[name]
        if global_dim_size != expected_global_dim_size:
          raise ValueError(f"The size of axis {name} was previously inferred to be "
                           f"{expected_global_dim_size}, but found an argument of shape {arg.shape} "
                           f"with in_axes specification {in_axes.user_repr}. Shape mismatch "
                           f"occurs in dimension {dim}: {global_dim_size} != {expected_global_dim_size}")
      global_axis_sizes[name] = global_dim_size
  return FrozenDict(global_axis_sizes)


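# Squeezes the (necessarily size-1) mapped positional dimensions of every
# argument before calling the function and unsqueezes them on the outputs.
# For example, with in_axes {'i': 0} an argument of shape (1, 3) is seen by
# the wrapped function as shape (3,), and with out_axes {'i': 1} a result of
# shape (3,) is returned as shape (3, 1).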
@lu.transformation
def hide_mapped_axes(flat_in_axes, flat_out_axes, *flat_args):
  def _squeeze_mapped_axes(arg, axes: AxisNamePos):
    for dim in sorted(axes.values(), reverse=True):
      arg = arg.squeeze(dim)
    return arg

  def _unsqueeze_mapped_axes(out, axes: AxisNamePos):
    try:
      return jnp.expand_dims(out, tuple(axes.values()))
    except ValueError as e:
      # Improve the axis-out-of-bounds errors
      # TODO(apaszke): Handle negative indices. Check for overlap too!
      if e.args[0].startswith('axis') and 'out of bounds' in e.args[0]:
        raise ValueError(f"One of the xmap outputs has an out_axes specification of "
                         f"{axes.user_repr}, which requires the result of the xmapped "
                         f"function to have at least {max(axes.values()) - len(axes) + 1} "
                         f"positional dimensions, but it only has {out.ndim}")
      raise

  squeezed_args = map(_squeeze_mapped_axes, flat_args, flat_in_axes)
  flat_outputs = yield squeezed_args, {}
  yield map(_unsqueeze_mapped_axes, flat_outputs, flat_out_axes)


def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]:
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  assert isinstance(jaxpr, core.Jaxpr)
  used_resources = set()
  for eqn in jaxpr.eqns:
    if eqn.primitive is xmap_p:
      if eqn.params['resource_env'].physical_mesh != resource_env.physical_mesh:
        raise RuntimeError("Changing the physical mesh is not allowed inside xmap.")
      used_resources |= set(it.chain(*eqn.params['axis_resources'].values()))
    updates = core.traverse_jaxpr_params(
        partial(_jaxpr_resources, resource_env=resource_env), eqn.params).values()
    for update in updates:
      used_resources |= update
  return used_resources


def _to_resource_axes(axes_specs: Sequence[AxisNamePos],
                      axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]):
  """
  Convert in/out_axes parameters ranging over logical dimensions to
  ones that range over resource dimensions.

  Note that values no longer have to be distinct, as multiple resource
  axes can tile a single positional axis. This is why the result is
  an OrderedDict with an implicit major-to-minor ordering.
  """
  return tuple(OrderedDict((resource_axis, pos_axis)
                           for logical_axis, pos_axis in axes.items()
                           for resource_axis in axis_resources[logical_axis])
               for axes in axes_specs)


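# For example, for x of shape (4, 3, 5) and axis=1, the leading axis is moved
# to position 1 and merged with the dimension already there, giving shape
# (3, 20); with axis=None the leading axis is dropped by taking its first
# slice, under the assumption that the output is constant along it.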
def _merge_leading_axis(x, axis: Optional[int]):
  if axis is None:
    # We assume that the output does not vary along the leading axis
    return lax.index_in_dim(x, 0, axis=0, keepdims=False)
  else:
    x_moved = moveaxis(x, 0, axis)
    shape = list(x_moved.shape)
    shape[axis:axis + 2] = [shape[axis] * shape[axis + 1]]
    return x_moved.reshape(shape)


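# For example, for x with x.shape[dim] == 12 and n == 3, tile i == 1 is the
# slice x[..., 4:8, ...] along `dim`.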
def _slice_tile(x, dim: Optional[int], i, n: int):
  """Selects the `i`th (out of `n`) tile of `x` along `dim`."""
  if dim is None: return x
  (tile_size, rem) = divmod(x.shape[dim], n)
  assert rem == 0, "Please open a bug report!"
  return lax.dynamic_slice_in_dim(x, i * tile_size, slice_size=tile_size, axis=dim)


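# For example, with loop resource axes {'l'}, axis_resources {'i': ('x', 'l')}
# splits into physical {'i': ('x',)} and loop {'i': ('l',)}, whereas
# {'i': ('l', 'x')} is rejected because a mesh axis follows a loop resource.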
def _unzip_axis_resources(axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
                          resource_env: ResourceEnv):
  """Splits axis_resources into separate dicts for physical and loop resources."""
  physical_axis_resources = {}
  loop_axis_resources = {}
  loop_resource_axes = resource_env.loop_resource_axes
  for axis, raxes in axis_resources.items():
    first_loop = 0
    for raxis in raxes:
      if raxis in loop_resource_axes:
        break
      else:
        first_loop += 1
    physical_axis_resources[axis] = raxes[:first_loop]
    loop_resources = loop_axis_resources[axis] = raxes[first_loop:]
    if not all(name in loop_resource_axes for name in loop_resources):
      raise NotImplementedError("Loop resources cannot appear before mesh axes "
                                "in the axis_resources argument")
  return physical_axis_resources, loop_axis_resources


def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
                                 out_axes: Sequence[AxisNamePos],
                                 global_axis_sizes: Dict[AxisName, int]):
  defined_axes = set(global_axis_sizes)
  for aval, axes in zip(out_avals, out_axes):
    if not isinstance(aval, core.ShapedArray):
      if axes:
        raise AssertionError(f"Only array abstract values can have non-empty "
                             f"out_axes, but {aval} has {axes}")
      continue
    undeclared_axes = (set(aval.named_shape) - set(axes)) & defined_axes
    if undeclared_axes:
      undeclared_axes_str = sorted(str(axis) for axis in undeclared_axes)
      raise TypeError(f"One of the xmap results has an out_axes specification of "
                      f"{axes.user_repr}, but is actually mapped along more axes "
                      f"defined by this xmap call: {', '.join(undeclared_axes_str)}")


def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
                                          global_axis_sizes, in_axes_flat,
                                          args_flat):
  @lru_cache()
  def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
    if not op_shardings.are_op_shardings_equal(
        in_sharding._to_xla_op_sharding(ndim),
        xmap_sharding._to_xla_op_sharding(ndim)):
      raise ValueError(
          f"Got an input {arr_flavor} to xmap with different partitioning than "
          "specified in xmap. The partitioning must match. "
          f"Got {arr_flavor} spec: {in_sharding.spec} and "
          f"xmap spec: {xmap_sharding.spec}")

  mesh_in_axes = EvaluationPlan.from_axis_resources(  # pytype: disable=wrong-arg-types  # always-use-return-annotations
      axis_resources, resource_env, global_axis_sizes).to_mesh_axes(in_axes_flat)
  for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
    if isinstance(arg, ArrayImpl):
      if not isinstance(arg.sharding, NamedSharding):
        continue
      mesh = arg.sharding.mesh
      if mesh != resource_env.physical_mesh:
        raise ValueError("xmap's mesh and Array's mesh should be equal. "
                         f"Got xmap mesh: {resource_env.physical_mesh},\n"
                         f"Array mesh: {mesh}")

      s = arg.sharding
      xmap_sharding = pxla.create_mesh_pspec_sharding(
          mesh, array_mapping_to_axis_resources(xmap_array_mapping))
      # This check is cached because comparing OpShardings is expensive during
      # dispatch, and if the shardings are the same there is no need to
      # compare twice.
      _check_sharding(s, xmap_sharding, arg.ndim, 'Array')


# TODO: We should relax this at least for "constructor primitives"
#       such as axis_index or zeros.
def _check_no_loop_collectives(jaxpr, loop_axis_resources):
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  def subst_no_loop(name):
    if loop_axis_resources.get(name, ()):
      raise RuntimeError(f"Named axes with loop resources assigned to them cannot "
                         f"be referenced inside the xmapped computation (e.g. in "
                         f"collectives), but `{name}` violates that rule")
    return (name,)
  for eqn in jaxpr.eqns:
    core.subst_axis_names(eqn.primitive, eqn.params, subst_no_loop, traverse=False)
    rec = partial(_check_no_loop_collectives, loop_axis_resources=loop_axis_resources)
    core.traverse_jaxpr_params(rec, eqn.params)


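# Emits a sharding constraint after every equation: each equation is given
# fresh temporary outputs, and each original output is then bound by a
# sharding_constraint, pinning the intermediate values so the SPMD partitioner
# has less freedom to reshard them (cf. the
# `experimental_xmap_ensure_fixed_sharding` flag below).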
def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name=None):
  rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
  if isinstance(jaxpr, core.ClosedJaxpr):
    return jaxpr.map_jaxpr(rec)
  assert isinstance(jaxpr, core.Jaxpr)
  if gen_fresh_name is None:
    gen_fresh_name = core.gensym([jaxpr])
  new_eqns = []
  for eqn in jaxpr.eqns:
    new_jaxpr_params = core.traverse_jaxpr_params(rec, eqn.params)
    tmp_outvars = [gen_fresh_name(v.aval) for v in eqn.outvars]
    new_eqns.append(eqn.replace(
        outvars=tmp_outvars, params=dict(eqn.params, **new_jaxpr_params)))
    for outvar, tmpvar in zip(eqn.outvars, tmp_outvars):
      mps = NamedSharding._from_parsed_pspec(
          resource_env.physical_mesh, ParsedPartitionSpec((), ()))
      unconstrained_dims = get_unconstrained_dims(mps)
      gspmd_sharding = GSPMDSharding.get_replicated(mps._device_assignment)
      new_eqns.append(core.JaxprEqn(
          [tmpvar], [outvar], sharding_constraint_p,
          dict(resource_env=resource_env,
               sharding=gspmd_sharding,
               unconstrained_dims=unconstrained_dims),
          set(),
          eqn.source_info))
  return jaxpr.replace(eqns=new_eqns)

def _flatten_axes(what, tree, axes, tupled_args):
  try:
    return tuple(flatten_axes(what, tree, axes, tupled_args=tupled_args))
  except ValueError:
    pass
  # Replace axis_resources with unparsed versions to avoid revealing internal details
  flatten_axes(what, tree, tree_map(lambda parsed: NoQuotesStr(parsed.user_repr), axes),
               tupled_args=tupled_args)
  raise AssertionError("Please open a bug report!")  # This should be unreachable

class NoQuotesStr(str):
  __repr__ = str.__str__


# -------- config flags --------

def _thread_local_flag_unsupported(_):
  raise RuntimeError("thread-local xmap flags not supported!")
def _clear_compilation_cache(_):
  make_xmap_callable.cache_clear()  # type: ignore

def _ensure_spmd_and(f):
  def update(v):
    if v and not config.experimental_xmap_spmd_lowering:
      raise RuntimeError("This flag requires enabling the experimental_xmap_spmd_lowering flag")
    return f(v)
  return update


try:
  config.define_bool_state(
      name="experimental_xmap_spmd_lowering",
      default=False,
      help=("When set, multi-device xmap computations will be compiled through "
            "the XLA SPMD partitioner instead of explicit cross-replica collectives. "
            "Not supported on CPU!"),
      update_global_hook=_clear_compilation_cache,
      update_thread_local_hook=_thread_local_flag_unsupported)
  config.define_bool_state(
      name="experimental_xmap_spmd_lowering_manual",
      default=False,
      help=("When set, multi-device xmap computations will be compiled using "
            "the MANUAL partitioning feature of the XLA SPMD partitioner instead of "
            "sharding constraints on vectorized code. "
            "Requires experimental_xmap_spmd_lowering!"),
      update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
      update_thread_local_hook=_thread_local_flag_unsupported)
  config.define_bool_state(
      name="experimental_xmap_ensure_fixed_sharding",
      default=False,
      help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will "
            "try to limit the flexibility of the automated SPMD partitioner heuristics "
            "by emitting additional sharding annotations for program intermediates."),
      update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
      update_thread_local_hook=_thread_local_flag_unsupported)
except Exception:
  raise ImportError("jax.experimental.maps has to be imported before JAX flags "
                    "are parsed")