2021-02-05 16:50:38 -08:00
|
|
|
|
# Copyright 2021 Google LLC
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
2021-05-06 12:34:15 -07:00
|
|
|
|
from enum import IntEnum
|
2021-04-15 06:12:18 -07:00
|
|
|
|
import numpy as np
|
2021-04-26 03:45:31 -07:00
|
|
|
|
from collections import OrderedDict, Counter
|
|
|
|
|
from typing import Callable, Sequence, Tuple, Union
|
2021-02-05 16:50:38 -08:00
|
|
|
|
from warnings import warn
|
2021-04-26 03:45:31 -07:00
|
|
|
|
import itertools as it
|
2021-04-30 09:56:53 -07:00
|
|
|
|
from functools import partial
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
|
|
|
|
from . import maps
|
2021-04-26 03:45:31 -07:00
|
|
|
|
from . import PartitionSpec
|
2021-02-05 16:50:38 -08:00
|
|
|
|
from .. import core
|
|
|
|
|
from .. import linear_util as lu
|
2021-04-13 09:42:54 -07:00
|
|
|
|
from .._src.api import _check_callable, _check_arg
|
2021-04-27 02:19:18 -07:00
|
|
|
|
from .._src import source_info_util
|
2021-02-05 16:50:38 -08:00
|
|
|
|
from ..api_util import (argnums_partial_except, flatten_axes,
|
|
|
|
|
flatten_fun_nokwargs, _ensure_index_tuple,
|
|
|
|
|
donation_vector, rebase_donate_argnums)
|
2021-04-27 02:19:18 -07:00
|
|
|
|
from ..errors import JAXTypeError
|
2021-02-05 16:50:38 -08:00
|
|
|
|
from ..interpreters import ad
|
|
|
|
|
from ..interpreters import pxla
|
|
|
|
|
from ..interpreters import xla
|
2021-05-06 12:34:15 -07:00
|
|
|
|
from ..interpreters import batching
|
2021-04-21 04:09:30 -07:00
|
|
|
|
from ..interpreters import partial_eval as pe
|
2021-02-05 16:50:38 -08:00
|
|
|
|
from ..lib import xla_bridge as xb
|
|
|
|
|
from ..lib import xla_client as xc
|
|
|
|
|
from ..tree_util import tree_flatten, tree_unflatten
|
|
|
|
|
from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
|
2021-04-21 11:04:52 -07:00
|
|
|
|
wrap_name, wraps, distributed_debug_log,
|
2021-05-06 12:34:15 -07:00
|
|
|
|
split_list, cache, tuple_insert)
|
2021-02-05 16:50:38 -08:00
|
|
|
|
# Short alias for the XLA client's op-building functions (used below as
# xops.Tuple and xops.Call in the pjit translation rule).
xops = xc._xla.ops
|
|
|
|
|
|
|
|
|
|
def pjit(fun: Callable,
         in_axis_resources,
         out_axis_resources,
         static_argnums: Union[int, Sequence[int]] = (),
         donate_argnums: Union[int, Sequence[int]] = ()):
  """Makes ``fun`` compiled and automatically partitioned across multiple devices.

  The returned function has semantics equivalent to those of ``fun``, but is
  compiled to an XLA computation that runs across multiple devices
  (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted
  version of ``fun`` would not fit in a single device's memory, or to speed up
  ``fun`` by running each operation in parallel across multiple devices.

  The partitioning over devices happens automatically based on the
  propagation of the input partitioning specified in ``in_axis_resources`` and
  the output partitioning specified in ``out_axis_resources``. The resources
  specified in those two arguments must refer to mesh axes, as defined by
  the :py:func:`jax.experimental.maps.mesh` context manager. Note that the mesh
  definition at ``pjit`` application time is ignored, and the returned function
  will use the mesh definition available at each call site.

  Inputs to a pjit'd function will be automatically partitioned across devices
  if they're not already correctly partitioned based on ``in_axis_resources``.
  In some scenarios, ensuring that the inputs are already correctly
  pre-partitioned can increase performance. For example, if passing the output
  of one pjit'd function to another pjit'd function (or the same pjit'd
  function in a loop), make sure the relevant ``out_axis_resources`` match the
  corresponding ``in_axis_resources``.

  .. note::
    **Multi-process platforms:** On multi-process platforms such as TPU pods,
    ``pjit`` can be used to run computations across all available devices across
    processes. To achieve this, ``pjit`` is designed to be used in SPMD Python
    programs, where every process is running the same Python code such that all
    processes run the same pjit'd function in the same order.

    When running in this configuration, the mesh should contain devices across
    all processes. However, any input argument dimensions partitioned over
    multi-process mesh axes should be of size equal to the corresponding *local*
    mesh axis size, and outputs will be similarly sized according to the local
    mesh. ``fun`` will still be executed across *all* devices in the mesh,
    including those from other processes, and will be given a global view of the
    data spread across multiple processes as a single array. However, outside
    of ``pjit`` every process only "sees" its local piece of the input and output,
    corresponding to its local sub-mesh.

    The SPMD model requires that the same multi-process ``pjit``'d functions must
    be run in the same order on all processes, but they can be interspersed with
    arbitrary operations running in a single process.

  Args:
    fun: Function to be compiled. Should be a pure function, as side-effects may
      only be executed once. Its arguments and return value should be arrays,
      scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
      Positional arguments indicated by ``static_argnums`` can be anything at
      all, provided they are hashable and have an equality operation defined.
      Static arguments are included as part of a compilation cache key, which is
      why hash and equality operators must be defined.
    in_axis_resources: Pytree of structure matching that of arguments to ``fun``,
      with all actual arguments replaced by resource assignment specifications.
      It is also valid to specify a pytree prefix (e.g. one value in place of a
      whole subtree), in which case the leaves get broadcast to all values in
      that subtree.

      The valid resource assignment specifications are:
        - :py:obj:`None`, in which case the value will be replicated on all devices
        - :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
          of the partitioned value. Each element can be a :py:obj:`None`, a mesh
          axis or a tuple of mesh axes, and specifies the set of resources assigned
          to partition the value's dimension matching its position in the spec.

      The size of every dimension has to be a multiple of the total number of
      resources assigned to it.
    out_axis_resources: Like ``in_axis_resources``, but specifies resource
      assignment for function outputs.
    static_argnums: An optional int or collection of ints that specify which
      positional arguments to treat as static (compile-time constant).
      Operations that only depend on static arguments will be constant-folded in
      Python (during tracing), and so the corresponding argument values can be
      any Python object.

      Static arguments should be hashable, meaning both ``__hash__`` and
      ``__eq__`` are implemented, and immutable. Calling the jitted function
      with different values for these constants will trigger recompilation.
      Arguments that are not arrays or containers thereof must be marked as
      static.

      If ``static_argnums`` is not provided, no arguments are treated as static.
    donate_argnums: Specify which arguments are "donated" to the computation.
      It is safe to donate arguments if you no longer need them once the
      computation has finished. In some cases XLA can make use of donated
      buffers to reduce the amount of memory needed to perform a computation,
      for example recycling one of your input buffers to store a result. You
      should not reuse buffers that you donate to a computation, JAX will raise
      an error if you try to.

  Returns:
    A wrapped version of ``fun``, set up for just-in-time compilation and
    automatically partitioned by the mesh available at each call site. The
    wrapped function accepts positional arguments only (keyword arguments raise
    ``NotImplementedError``), and raises ``RuntimeError`` when called outside a
    non-empty mesh.

  For example, a convolution operator can be automatically partitioned over
  an arbitrary set of devices by a single ``pjit`` application:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental.maps import mesh
  >>> from jax.experimental.pjit import PartitionSpec, pjit
  >>>
  >>> x = jnp.arange(8, dtype=jnp.float32)
  >>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
  ...         in_axis_resources=None, out_axis_resources=PartitionSpec('devices'))
  >>> with mesh(jax.devices(), ('devices',)):
  ...   print(f(x))  # doctest: +SKIP
  [ 0.5  2.   4.   6.   8.  10.  12.  10. ]
  """
  warn("pjit is an experimental feature and probably has bugs!")
  _check_callable(fun)

  # To be a tree prefix of the positional args tuple, in_axes can never be a
  # list: if in_axes is not a leaf, it must be a tuple of trees. However,
  # in cases like these users expect tuples and lists to be treated
  # essentially interchangeably, so we canonicalize lists to tuples here
  # rather than raising an error. https://github.com/google/jax/issues/2367
  if isinstance(in_axis_resources, list):
    in_axis_resources = tuple(in_axis_resources)
  if isinstance(out_axis_resources, list):
    out_axis_resources = tuple(out_axis_resources)

  # Parse and validate the specs eagerly, so malformed specs are reported when
  # pjit is applied rather than on the first call.
  in_axis_resources, _, _ = _prepare_axis_resources(
      in_axis_resources, "in_axis_resources")
  out_axis_resources, _, _ = _prepare_axis_resources(
      out_axis_resources, "out_axis_resources")

  static_argnums = _ensure_index_tuple(static_argnums)
  donate_argnums = _ensure_index_tuple(donate_argnums)
  # Presumably re-indexes the donated positions relative to the dynamic
  # (non-static) arguments, since donation_vector below works on dyn_args.
  donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

  @wraps(fun)
  def wrapped(*args, **kwargs):
    if kwargs:
      raise NotImplementedError("pjit over kwargs not yet supported")
    if max(static_argnums + donate_argnums, default=-1) >= len(args):
      raise ValueError(f"jitted function has static_argnums={static_argnums}, "
                       f"donate_argnums={donate_argnums} but "
                       f"was called with only {len(args)} positional arguments.")

    # Putting this outside of wrapped would make resources lexically scoped
    resource_env = maps.thread_resources.env
    mesh = resource_env.physical_mesh
    if mesh.empty:
      raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
                         "it's defined at the call site?")

    f = lu.wrap_init(fun)
    if static_argnums:
      f, dyn_args = argnums_partial_except(
          f, static_argnums, args, allow_invalid=False)
    else:
      dyn_args = args

    # Flatten only the dynamic arguments: `f` now closes over the static ones,
    # so they must not reach _check_arg nor appear in in_tree.
    args_flat, in_tree = tree_flatten(dyn_args)
    for arg in args_flat:
      _check_arg(arg)
    flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
    if donate_argnums:
      donated_invars = donation_vector(donate_argnums, dyn_args, ())
    else:
      donated_invars = (False,) * len(args_flat)

    local_in_avals = tuple(core.raise_to_shaped(core.get_aval(a)) for a in args_flat)
    # The resource pytrees and out_tree thunk are wrapped as hashable values so
    # they can be part of _pjit_jaxpr's cache key.
    jaxpr, in_axis_resources_flat, out_axis_resources_flat = \
        _pjit_jaxpr(flat_fun, mesh, local_in_avals,
                    in_tree, hashable_pytree(in_axis_resources),
                    HashableFunction(out_tree, closure=()),
                    hashable_pytree(out_axis_resources))

    out = pjit_p.bind(
        *args_flat,
        jaxpr=jaxpr,
        in_axis_resources=in_axis_resources_flat,
        out_axis_resources=out_axis_resources_flat,
        resource_env=resource_env,
        donated_invars=donated_invars,
        name=flat_fun.__name__)
    return tree_unflatten(out_tree(), out)

  return wrapped
|
|
|
|
|
|
2021-04-30 09:56:53 -07:00
|
|
|
|
class _ListWithW(list):
  """A plain ``list`` that can additionally be the target of a weak reference.

  Built-in ``list`` objects cannot be weakly referenced. ``lu.cache`` needs to
  create weakrefs to the values it caches, so ``_pjit_jaxpr`` returns its
  results in this type instead of a tuple.
  """
  # Only add the weakref slot; no per-instance __dict__.
  __slots__ = ("__weakref__",)
|
|
|
|
|
|
2021-05-05 06:07:16 -07:00
|
|
|
|
def hashable_pytree(pytree):
  """Wraps ``pytree`` in a hashable zero-argument thunk that rebuilds it.

  The pytree is flattened once. The returned HashableFunction is given
  ``closure=(treedef, leaves)``, presumably so that its equality/hash follow
  the tree contents rather than object identity. This lets pytrees of resource
  specs be used as arguments to cached functions such as ``_pjit_jaxpr``. The
  leaves therefore need to be hashable.
  """
  flat = tree_flatten(pytree)
  structure = flat[1]
  frozen_leaves = tuple(flat[0])
  cache_key = (structure, frozen_leaves)
  return HashableFunction(lambda: tree_unflatten(structure, frozen_leaves),
                          closure=cache_key)
|
|
|
|
|
|
2021-04-30 09:56:53 -07:00
|
|
|
|
@lu.cache
def _pjit_jaxpr(fun, mesh, local_in_avals,
                in_tree, in_axis_resources_thunk,
                out_tree, out_axis_resources_thunk):
  """Traces ``fun`` over global avals and flattens the resource specs.

  Args:
    fun: the flattened ``lu.WrappedFun`` to trace.
    mesh: the physical mesh active at the call site.
    local_in_avals: per-process avals of the flat arguments.
    in_tree: treedef of the (dynamic) positional arguments.
    in_axis_resources_thunk: hashable thunk returning the parsed input specs.
    out_tree: thunk returning the output treedef; only valid after tracing.
    out_axis_resources_thunk: hashable thunk returning the parsed output specs.

  Returns:
    A ``_ListWithW`` holding ``[closed_jaxpr, flat_in_specs, flat_out_specs]``.

  Raises:
    ValueError: if argument or output shapes are incompatible with their specs.
  """
  flat_in_specs = tuple(flatten_axes("pjit in_axis_resources",
                                     in_tree, in_axis_resources_thunk()))
  # Arguments are validated against the *local* sub-mesh shape, because callers
  # pass per-process data; they are then widened to global avals for tracing.
  _check_shapes_against_resources("pjit arguments", False, mesh.local_mesh.shape,
                                  local_in_avals, flat_in_specs)
  global_in_avals = local_to_global(mesh, local_in_avals, flat_in_specs)

  open_jaxpr, global_out_avals, traced_consts = pe.trace_to_jaxpr_dynamic(
      fun, global_in_avals)
  closed_jaxpr = core.ClosedJaxpr(open_jaxpr, traced_consts)

  # out_tree() is only populated once tracing has run, so this must come after.
  flat_out_specs = tuple(flatten_axes("pjit out_axis_resources",
                                      out_tree(), out_axis_resources_thunk()))
  _check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape,
                                  global_out_avals, flat_out_specs)
  # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
  return _ListWithW([closed_jaxpr, flat_in_specs, flat_out_specs])
|
|
|
|
|
|
|
|
|
|
|
2021-05-06 12:34:15 -07:00
|
|
|
|
class SpecSync(IntEnum):
  """How faithfully a spec's stored user input still describes its partitions.

  Transformations such as batching rewrite a spec's partitions without touching
  the user-provided spec it was parsed from. Recording the degree of divergence
  lets error messages refuse to display a modified spec as if the user had
  written it. Higher values mean "more in sync", so levels compare with ``<``.
  """
  # Dimensions were permuted/inserted, but no new sharding axes were added.
  DIM_PERMUTE = 1
  # Partitions are exactly what the user specified.
  IN_SYNC = 2
|
|
|
|
|
|
2021-04-26 03:45:31 -07:00
|
|
|
|
class ParsedPartitionSpec:
  """Normalized form of a user-supplied ``PartitionSpec`` (or ``None``).

  ``partitions`` holds one entry per leading array dimension. Each entry is a
  (possibly empty) tuple of mesh axis names that partition that dimension. The
  original user input is kept for error messages, together with a ``SpecSync``
  level recording whether ``partitions`` has since diverged from it.
  """
  __slots__ = ('partitions', 'unsafe_user_spec', 'sync')

  def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
    self.partitions = tuple(partitions)
    # "unsafe" because it may no longer describe `partitions`; read it through
    # `user_spec` / `unsynced_user_spec`, which check `sync` first.
    self.unsafe_user_spec = user_spec
    self.sync = sync

  @property
  def user_spec(self):
    """The original user spec; raises AssertionError unless fully in sync."""
    return self.unsynced_user_spec(SpecSync.IN_SYNC)

  def unsynced_user_spec(self, min_sync):
    """Returns the original user spec if ``self.sync >= min_sync``.

    Raises:
      AssertionError: if this spec has diverged further than ``min_sync`` allows.
    """
    if self.sync < min_sync:
      raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
    return self.unsafe_user_spec

  def insert_axis_partitions(self, dim, val):
    """Returns a new spec with ``val`` inserted as the partitioning of ``dim``.

    If the spec is shorter than ``dim``, it is first padded with ``()``
    (unpartitioned) entries. The resulting sync level is never higher than this
    spec's: a derived spec cannot be more faithful to the user's input than
    the spec it was derived from.
    """
    parts = self.partitions
    too_short = dim - len(parts)
    if too_short > 0:
      parts += ((),) * too_short
    new_partitions = tuple_insert(parts, dim, val)
    new_sync = SpecSync.DIM_PERMUTE if val == () else SpecSync.IN_SYNC
    new_sync = min(self.sync, new_sync)
    return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)

  @classmethod
  def from_user_input(cls, entry, arg_name):
    """Parses ``None`` or a ``PartitionSpec`` into a ``ParsedPartitionSpec``.

    Each element of the spec is normalized to a tuple: ``None`` -> ``()``,
    a list/tuple of axes -> tuple, a single axis -> 1-tuple.

    Raises:
      TypeError: if ``entry`` is neither ``None`` nor a ``PartitionSpec``.
    """
    if entry is None:
      return cls(entry, ())
    if not isinstance(entry, PartitionSpec):
      raise TypeError(f"{arg_name} are expected to be "
                      f"PartitionSpec instances or None, but got {entry}")
    axis_specs = []
    for axis_spec in entry:
      if axis_spec is None:
        axis_spec = ()
      elif isinstance(axis_spec, (list, tuple)):
        axis_spec = tuple(axis_spec)
      else:
        axis_spec = (axis_spec,)
      axis_specs.append(axis_spec)
    return cls(entry, axis_specs)

  def __hash__(self):
    return hash((self.partitions, self.sync))

  def __eq__(self, other):
    # Comparing against unrelated objects (e.g. None) must not raise.
    if not isinstance(other, ParsedPartitionSpec):
      return NotImplemented
    return (self.partitions == other.partitions and
            self.unsafe_user_spec == other.unsafe_user_spec and
            self.sync == other.sync)

  def __len__(self):
    return len(self.partitions)

  def __getitem__(self, i):
    return self.partitions[i]

  def __iter__(self):
    return iter(self.partitions)
|
|
|
|
|
|
|
|
|
|
# Spec with no user input (None) and no partitioned dimensions, i.e. a value
# replicated on all devices. get_sharding_proto checks token operands against
# this exact object (by identity).
REPLICATED = ParsedPartitionSpec(None, ())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_axis_resources(axis_resources, arg_name):
  """Parses and validates a pytree of user axis-resource specifications.

  Args:
    axis_resources: pytree whose leaves are ``None`` or ``PartitionSpec``.
    arg_name: name of the user-facing argument, used in error messages.

  Returns:
    A tuple ``(parsed_tree, parsed_leaves, treedef)``, where the leaves are
    ``ParsedPartitionSpec`` instances.

  Raises:
    TypeError: if a leaf is neither ``None`` nor a ``PartitionSpec``.
    ValueError: if a single spec uses some mesh axis more than once.
  """
  # None would otherwise be treated as an empty subtree by tree_flatten, but
  # here it is a meaningful leaf ("replicated"), so keep it as one explicitly.
  raw_leaves, structure = tree_flatten(axis_resources,
                                       is_leaf=lambda leaf: leaf is None)
  description = f"{arg_name} leaf specifications"
  parsed = []
  for raw in raw_leaves:
    parsed.append(ParsedPartitionSpec.from_user_input(raw, description))
  _check_unique_resources(parsed, arg_name)
  return tree_unflatten(structure, parsed), parsed, structure
|
|
|
|
|
|
|
|
|
|
def _check_unique_resources(axis_resources, arg_name):
  """Ensures no single spec maps the same mesh axis to several dimensions.

  Different specs may reuse the same mesh axis; only duplicates within one spec
  are rejected.

  Raises:
    ValueError: naming the offending spec and the duplicated mesh axes.
  """
  for spec in axis_resources:
    if not spec:
      continue
    usage = Counter(it.chain.from_iterable(spec))
    duplicated = [mesh_axis for mesh_axis, uses in usage.items() if uses > 1]
    if not duplicated:
      continue
    raise ValueError(f"A single {arg_name} specification can map every mesh axis "
                     f"to at most one positional dimension, but {spec.user_spec} "
                     f"has duplicate entries for {maps.show_axes(duplicated)}")
|
|
|
|
|
|
2021-04-30 09:56:53 -07:00
|
|
|
|
def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape, flat_avals, flat_axis_resources):
  """Validates that every aval's shape is compatible with its partitioning.

  Args:
    what: description of the values being checked, used in error messages.
    is_global_shape: whether the shapes are global (only affects messages).
    mesh_shape: mapping from mesh axis name to its size.
    flat_avals: values exposing ``.shape``.
    flat_axis_resources: one parsed spec per aval.

  Raises:
    ValueError: if a spec has more entries than the aval's rank, names a mesh
      axis missing from ``mesh_shape``, or a dimension is not divisible by the
      product of the mesh axis sizes assigned to it.
  """
  global_str = " global" if is_global_shape else ""
  for aval, spec in zip(flat_avals, flat_axis_resources):
    shape = aval.shape
    if len(spec) > len(shape):
      raise ValueError(f"One of {what} was given the resource assignment "
                       f"of {spec.user_spec}, which implies that "
                       f"it has a rank of at least {len(spec)}, "
                       f"but it is {len(shape)}")
    for dim, dim_mesh_axes in enumerate(spec):
      try:
        axis_sizes = [mesh_shape[mesh_axis] for mesh_axis in dim_mesh_axes]
      except KeyError as err:
        raise ValueError(f"One of {what} was given the resource assignment "
                         f"of {spec.user_spec}, but resource axis "
                         f"{err.args[0]} is undefined. Did you forget to declare the mesh?") from None
      num_shards = int(np.prod(axis_sizes, dtype=np.int64))
      dim_size = shape[dim]
      if dim_size % num_shards:
        raise ValueError(f"One of {what} was given the resource assignment "
                         f"of {spec.user_spec}, which implies that "
                         f"the{global_str} size of its dimension {dim} should be "
                         f"divisible by {num_shards}, but it is equal to {dim_size}")
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
2021-04-30 09:56:53 -07:00
|
|
|
|
# -------------------- pjit rules --------------------
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
2021-04-30 09:56:53 -07:00
|
|
|
|
# The pjit primitive. pjit() binds it with params jaxpr, in_axis_resources,
# out_axis_resources, resource_env, donated_invars and name.
pjit_p = core.Primitive("pjit")
# A pjit application returns a list of outputs (one per jaxpr output).
pjit_p.multiple_results = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pjit_call_impl(*args, jaxpr,
                    in_axis_resources, out_axis_resources,
                    resource_env, donated_invars, name):
  """Eager implementation of ``pjit_p``.

  Fetches the (cached) compiled callable for these params, logs the call and
  runs it on ``args``.
  """
  executable = pjit_callable(jaxpr, in_axis_resources, out_axis_resources,
                             resource_env, donated_invars, name)
  physical_mesh = resource_env.physical_mesh
  distributed_debug_log(("Running pjit'd function", name),
                        ("mesh", physical_mesh))
  return executable(*args)

pjit_p.def_impl(_pjit_call_impl)
|
|
|
|
|
|
|
|
|
|
@cache()
def pjit_callable(
    jaxpr: core.ClosedJaxpr,
    in_axis_resources: Tuple[ParsedPartitionSpec, ...],
    out_axis_resources: Tuple[ParsedPartitionSpec, ...],
    resource_env,
    donated_invars,
    name: str):
  """Builds (and memoizes via ``@cache()``) the mesh executable for a pjit call.

  The resource specs are converted to array mappings. ``jaxpr.in_avals`` are
  converted to per-process avals, and everything is handed to
  ``pxla.mesh_callable`` for the physical mesh. The positional ``None`` and
  ``True`` arguments are forwarded unchanged; see ``pxla.mesh_callable`` for
  their meaning.
  """
  physical_mesh = resource_env.physical_mesh
  in_mappings = [get_array_mapping(spec) for spec in in_axis_resources]
  out_mappings = [get_array_mapping(spec) for spec in out_axis_resources]
  jaxpr_fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  local_in_avals = global_to_local(physical_mesh, jaxpr.in_avals,
                                   in_axis_resources)
  # TODO(skye): allow for using a submesh of physical_mesh
  return pxla.mesh_callable(jaxpr_fun, name, None, physical_mesh,
                            in_mappings, out_mappings, donated_invars,
                            True, *local_in_avals, tile_by_mesh_axes=False,
                            do_resource_typecheck="pjit")
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
2021-04-30 09:56:53 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env, **_):
  """Abstract evaluation rule for ``pjit_p``.

  The jaxpr's output avals are global; callers see the per-process (local)
  avals implied by ``out_axis_resources`` on the physical mesh.
  """
  physical_mesh = resource_env.physical_mesh
  global_out_avals = jaxpr.out_avals
  return global_to_local(physical_mesh, global_out_avals, out_axis_resources)

pjit_p.def_abstract_eval(_pjit_abstract_eval)
|
|
|
|
|
|
2021-04-21 04:09:30 -07:00
|
|
|
|
|
|
|
|
|
def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
                           jaxpr, in_axis_resources, out_axis_resources,
                           resource_env, donated_invars):
  """XLA translation rule for ``pjit_p``.

  Builds a subcomputation whose parameters and results carry sharding
  annotations derived from the resource specs, and emits an XLA ``Call`` of it
  on ``in_nodes``.
  """
  mesh = resource_env.physical_mesh
  subc = xc.XlaBuilder(f"pjit_{name}")

  sharded_params = []
  for param_index, (node, spec) in enumerate(safe_zip(in_nodes, in_axis_resources)):
    # N.B. inlined calls shouldn't have shardings set directly on the inputs or
    # outputs (set_sharding_proto adds an identity operation).
    param = xb.parameter(subc, param_index, c.GetShape(node))
    # The sharding is computed from the outer node, which lives in `c`.
    param_sharding = get_sharding_proto(c, node, spec, mesh)
    sharded_params.append(xb.set_sharding_proto(subc, param, param_sharding))

  # TODO: Think about how to avoid duplicating constants with the outer jaxpr
  body_outs = xla.jaxpr_subcomp(
      subc, jaxpr.jaxpr, backend, axis_env, xla._xla_consts(subc, jaxpr.consts),
      extend_name_stack(name_stack, wrap_name(name, "pjit")), *sharded_params)

  sharded_outs = []
  for out, spec in safe_zip(body_outs, out_axis_resources):
    out_sharding = get_sharding_proto(subc, out, spec, mesh)
    sharded_outs.append(xb.set_sharding_proto(subc, out, out_sharding))

  built = subc.build(xops.Tuple(subc, sharded_outs))
  return xops.Call(c, built, list(in_nodes))

xla.call_translations[pjit_p] = _pjit_translation_rule
|
|
|
|
|
|
|
|
|
|
|
2021-05-06 12:34:15 -07:00
|
|
|
|
def _pjit_batcher(vals_in, dims_in,
                  axis_name, main_type,
                  jaxpr, in_axis_resources, out_axis_resources,
                  resource_env, donated_invars, name):
  """Batching rule for ``pjit_p``.

  Moves every batch dimension to position 0, batches the jaxpr, and rebinds
  ``pjit_p``. The new leading dimension of mapped inputs/outputs is inserted
  into their specs as unpartitioned (``()``).

  Returns:
    ``(vals_out, dims_out)`` where mapped outputs are batched along axis 0.
  """
  not_mapped = batching.not_mapped
  # All mapped inputs must agree on a single batch size.
  axis_size, = {x.shape[d] for x, d in zip(vals_in, dims_in) if d is not not_mapped}

  # batch_jaxpr expects all batching dimensions to be equal to 0
  moved_vals = []
  for x, d in zip(vals_in, dims_in):
    if d is not_mapped or d == 0:
      moved_vals.append(x)
    else:
      moved_vals.append(batching.moveaxis(x, d, 0))

  is_mapped_in = [d is not not_mapped for d in dims_in]
  new_jaxpr, is_mapped_out = batching.batch_jaxpr(
      jaxpr, axis_size, is_mapped_in,
      instantiate=False, axis_name=axis_name, main_type=main_type)

  def with_batch_dim(specs, mapped_flags):
    # Prepend an unpartitioned dimension to the spec of every mapped value.
    return tuple(spec.insert_axis_partitions(0, ()) if mapped else spec
                 for mapped, spec in zip(mapped_flags, specs))

  vals_out = pjit_p.bind(
      *moved_vals,
      jaxpr=new_jaxpr,
      in_axis_resources=with_batch_dim(in_axis_resources, is_mapped_in),
      out_axis_resources=with_batch_dim(out_axis_resources, is_mapped_out),
      resource_env=resource_env,
      donated_invars=donated_invars,
      name=name)
  dims_out = [0 if mapped else not_mapped for mapped in is_mapped_out]
  return vals_out, dims_out

batching.initial_style_batchers[pjit_p] = _pjit_batcher
|
|
|
|
|
|
|
|
|
|
|
2021-04-30 09:56:53 -07:00
|
|
|
|
def _pjit_jvp(primals_in, tangents_in,
              jaxpr, in_axis_resources, out_axis_resources,
              resource_env, donated_invars, name):
  """JVP rule for ``pjit_p``.

  Binds ``pjit_p`` on the JVP jaxpr, passing the primals followed by only the
  nonzero tangents. Each tangent reuses the resource spec (and donation flag)
  of its primal counterpart.

  Returns:
    ``(primals_out, tangents_out)``; output tangents the JVP jaxpr reports as
    zero are returned as ``ad.Zero`` of the matching output aval.
  """
  is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
  jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
      jaxpr, is_nz_tangents_in, instantiate=False)

  def _filter_zeros(is_nz_l, l):
    return (x for nz, x in zip(is_nz_l, l) if nz)
  _filter_zeros_in = partial(_filter_zeros, is_nz_tangents_in)
  _filter_zeros_out = partial(_filter_zeros, is_nz_tangents_out)
  outputs = pjit_p.bind(
      *primals_in, *_filter_zeros_in(tangents_in),
      jaxpr=jaxpr_jvp,
      in_axis_resources=(*in_axis_resources, *_filter_zeros_in(in_axis_resources)),
      out_axis_resources=(*out_axis_resources, *_filter_zeros_out(out_axis_resources)),
      resource_env=resource_env,
      donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
      name=wrap_name(name, 'jvp'))

  # `outputs` holds every primal output followed by only the *nonzero* tangent
  # outputs, so split at the number of primal outputs. Splitting from the end
  # by len(is_nz_tangents_out) is only correct when no tangent is zero.
  num_primals_out = len(jaxpr.jaxpr.outvars)
  primals_out, tangents_out = split_list(outputs, [num_primals_out])
  assert len(tangents_out) == sum(is_nz_tangents_out)
  tangents_out_it = iter(tangents_out)
  return primals_out, [next(tangents_out_it) if nz else ad.Zero(aval)
                       for nz, aval in zip(is_nz_tangents_out, jaxpr.out_avals)]

ad.primitive_jvps[pjit_p] = _pjit_jvp
|
|
|
|
|
|
2021-04-22 15:30:03 -07:00
|
|
|
|
|
2021-04-27 02:19:18 -07:00
|
|
|
|
def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_axis_resources):
  """Rejects mesh axes used both positionally (pjit) and by named axes (xmap).

  Args:
    what: description of the checked value, used in the error message.
    aval: abstract value exposing ``named_shape``.
    pos_axis_resources: positional spec (iterable of per-dimension mesh axes).
    named_axis_resources: mapping from named axis to the mesh axes it uses.

  Raises:
    JAXTypeError: if any mesh axis is used by both kinds of partitioning.
  """
  positional_mesh_axes = set(it.chain.from_iterable(pos_axis_resources))
  named_mesh_axes = set()
  for named_axis in aval.named_shape:
    named_mesh_axes.update(named_axis_resources[named_axis])
  shared = positional_mesh_axes & named_mesh_axes
  if not shared:
    return
  raise JAXTypeError(
      f"{what} has an axis resources specification of "
      f"{pos_axis_resources.unsynced_user_spec(SpecSync.DIM_PERMUTE)} "
      f"that uses one or more mesh axes already used by xmap to partition "
      f"a named axis appearing in its named_shape (both use mesh axes "
      f"{maps.show_axes(shared)})")
|
|
|
|
|
|
|
|
|
|
def _resource_typing_pjit(avals, params, source_info, named_axis_resources):
  """Resource-typing rule for ``pjit_p``.

  Checks the inputs, then type-checks the body jaxpr, then checks the outputs.
  Avals are taken from ``params["jaxpr"]`` rather than from ``avals``.
  """
  jaxpr = params["jaxpr"]

  def check_side(label, side_avals, side_specs):
    for side_aval, spec in zip(side_avals, side_specs):
      _check_resources_against_named_axes(label, side_aval, spec, named_axis_resources)

  check_side("pjit input", jaxpr.in_avals, params['in_axis_resources'])
  pxla.resource_typecheck(
      jaxpr.jaxpr, named_axis_resources,
      lambda: (f"a pjit'ed function {params['name']} "
               f"(pjit called at {source_info_util.summarize(source_info)})"))
  check_side("pjit output", jaxpr.out_avals, params['out_axis_resources'])

pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit
|
2021-04-27 02:19:18 -07:00
|
|
|
|
|
2021-04-22 15:30:03 -07:00
|
|
|
|
|
2021-04-21 04:09:30 -07:00
|
|
|
|
# -------------------- with_sharding_constraint --------------------
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
|
|
|
|
def with_sharding_constraint(x, axis_resources):
  """Constrains the partitioning of the pytree ``x`` inside a pjit'd function.

  Args:
    x: pytree of arrays to constrain.
    axis_resources: pytree of ``None``/``PartitionSpec`` leaves matching ``x``
      (presumably a pytree prefix is accepted, as flatten_axes is used, like
      pjit's ``in_axis_resources``).

  Returns:
    A pytree with the same structure as ``x`` carrying the constraint.

  Raises:
    TypeError / ValueError: for malformed specs or shapes incompatible with the
      current mesh. Calling it outside pjit raises NotImplementedError when the
      primitive is evaluated eagerly.
  """
  leaves, treedef = tree_flatten(x)
  parsed_tree, _, _ = _prepare_axis_resources(axis_resources, "axis_resources")
  specs_flat = tuple(
      flatten_axes("with_sharding_constraint axis_resources",
                   treedef, parsed_tree))
  env = maps.thread_resources.env
  physical_mesh = env.physical_mesh
  _check_shapes_against_resources(
      "with_sharding_constraint arguments",
      physical_mesh.is_multi_process, physical_mesh.shape,
      leaves, specs_flat)
  constrained = []
  for leaf, spec in safe_zip(leaves, specs_flat):
    constrained.append(
        sharding_constraint_p.bind(leaf, axis_resources=spec, resource_env=env))
  return tree_unflatten(treedef, constrained)
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
|
|
|
|
def _sharding_constraint_impl(x, axis_resources, resource_env):
  """Eager implementation of ``sharding_constraint_p``; always raises.

  Raises:
    NotImplementedError: the constraint only has meaning inside pjit().
  """
  # TODO(skye): can we also prevent this from being called in other
  # non-pjit contexts? (e.g. pmap, control flow)
  message = "with_sharding_constraint() should only be called inside pjit()"
  raise NotImplementedError(message)
|
|
|
|
|
|
|
|
|
|
def _sharding_constraint_translation_rule(c, x_node, axis_resources, resource_env):
  """XLA lowering: annotate ``x_node`` with the sharding its spec implies."""
  sharding = get_sharding_proto(c, x_node, axis_resources,
                                resource_env.physical_mesh)
  return xb.set_sharding_proto(c, x_node, sharding)
|
|
|
|
|
|
|
|
|
|
def _sharding_constraint_abstract_eval(x, **unused):
  """The constraint does not change the abstract value of its operand."""
  return x


def _sharding_constraint_transpose(ct, _, axis_resources, resource_env):
  """Transpose rule: the cotangent receives the same sharding constraint."""
  constrained_ct = sharding_constraint_p.bind(
      ct, axis_resources=axis_resources, resource_env=resource_env)
  return (constrained_ct,)


# Primitive behind with_sharding_constraint: identity on values, linear for AD,
# and lowered to a sharding annotation.
sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(_sharding_constraint_abstract_eval)
ad.deflinear2(sharding_constraint_p, _sharding_constraint_transpose)
xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
|
|
|
|
|
|
2021-04-27 02:19:18 -07:00
|
|
|
|
def _resource_typing_sharding_constraint(avals, params, source_info, named_axis_resources):
  """Resource-typing rule for ``sharding_constraint_p`` (exactly one operand)."""
  (operand_aval,) = avals
  _check_resources_against_named_axes(
      "with_sharding_constraint input", operand_aval,
      params['axis_resources'], named_axis_resources)

pxla.custom_resource_typing_rules[sharding_constraint_p] = \
    _resource_typing_sharding_constraint
|
|
|
|
|
|
2021-04-21 04:09:30 -07:00
|
|
|
|
# -------------------- helpers --------------------
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
2021-04-26 03:45:31 -07:00
|
|
|
|
def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
  """Maps each mesh axis name to the positional dimension it partitions.

  Keys are inserted dimension by dimension, and within a dimension in spec
  order. For example, partitions ``((), ('x', 'y'), ('z',))`` give
  ``OrderedDict(x=1, y=1, z=2)``.
  """
  mapping = OrderedDict()
  for dim, mesh_axes in enumerate(axis_resources):
    for mesh_axis in mesh_axes:
      mapping[mesh_axis] = dim
  return mapping
|
2021-02-05 16:50:38 -08:00
|
|
|
|
|
2021-06-01 14:32:59 +03:00
|
|
|
|
def get_sharding_proto(c, xla_op, axis_resources: ParsedPartitionSpec,
                       mesh: maps.Mesh) -> xc.OpSharding:
  """Computes the OpSharding for ``xla_op`` (built in ``c``) under a spec.

  The aval is reconstructed from the op's XLA shape. Token-typed ops have no
  array shape and may only be given the REPLICATED spec.
  """
  xla_shape = c.GetShape(xla_op)
  if not xla_shape.is_token():
    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
  else:
    assert axis_resources is REPLICATED
    aval = core.abstract_token
  return get_aval_sharding_proto(aval, axis_resources, mesh)
|
2021-06-01 14:32:59 +03:00
|
|
|
|
|
|
|
|
|
|
2021-06-01 22:50:12 -07:00
|
|
|
|
def get_aval_sharding_proto(aval: core.AbstractValue,
                            axis_resources: ParsedPartitionSpec,
                            mesh: maps.Mesh) -> xc.OpSharding:
  """Computes the OpSharding for ``aval`` partitioned per ``axis_resources``."""
  make_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
  sharding_spec = make_sharding_spec(aval, get_array_mapping(axis_resources))
  return sharding_spec.sharding_proto()
|
2021-04-30 09:56:53 -07:00
|
|
|
|
|
|
|
|
|
def global_to_local(mesh, avals, axes):
  """Converts global avals to per-process avals, one spec per aval.

  Delegates to ``mesh.global_to_local`` with each spec's array mapping and
  returns a list.
  """
  local_avals = []
  for aval, aval_axes in zip(avals, axes):
    local_avals.append(mesh.global_to_local(get_array_mapping(aval_axes), aval))
  return local_avals
|
|
|
|
|
|
|
|
|
|
def local_to_global(mesh, avals, axes):
  """Converts per-process avals to global avals, one spec per aval.

  Delegates to ``mesh.local_to_global`` with each spec's array mapping and
  returns a list.
  """
  global_avals = []
  for aval, aval_axes in zip(avals, axes):
    global_avals.append(mesh.local_to_global(get_array_mapping(aval_axes), aval))
  return global_avals
|