rocm_jax/jax/_src/api.py

3357 lines
141 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX user-facing transformations and utilities.
The transformations here mostly wrap internal transformations, providing
convenience flags to control behavior and handling Python containers of
arguments and outputs. The Python containers handled are pytrees (see
tree_util.py), which include nested tuples/lists/dicts, where the leaves are
arrays.
"""
import collections
import functools
from functools import partial
import inspect
import itertools as it
from typing import (Any, Callable, Generator, Iterable, NamedTuple, Mapping,
Optional, Sequence, Tuple, TypeVar, Union, overload, Dict,
Hashable, List)
from typing_extensions import Literal
from warnings import warn
import numpy as np
from contextlib import contextmanager, ExitStack
import jax
from jax import core
from jax import linear_util as lu
from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)
from jax._src import callback as jcb
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
validate_argnames, validate_argnums)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, new_name_stack, wrap_name, cache,
wraps, HashableFunction, weakref_lru_cache)
# Unused imports to be exported
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
from jax._src.ad_checkpoint import _remat_static_argnums
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.config import (
flags, config, bool_env,
disable_jit as _disable_jit,
debug_nans as config_debug_nans,
debug_infs as config_debug_infs,
_thread_local_state as config_thread_local_state,
explicit_device_put_scope as config_explicit_device_put_scope,
explicit_device_get_scope as config_explicit_device_get_scope)
traceback_util.register_exclusion(__file__)

# dtype lookup that always canonicalizes (e.g. respects the x64 setting).
_dtype = partial(dtypes.dtype, canonicalize=True)

AxisName = Any

# These TypeVars are used below to express the fact that function types
# (i.e. call signatures) are invariant under the vmap transformation.
F = TypeVar("F", bound=Callable)
T = TypeVar("T")
U = TypeVar("U")

# Deliberately shadow the builtins with ``safe_map``/``safe_zip`` from
# jax._src.util; the builtin versions stay reachable as unsafe_map/unsafe_zip.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

FLAGS = flags.FLAGS

flags.DEFINE_bool(
    "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
    "A flag enabling the C++ jax.jit fast path. "
    "Set this to `False` only if it crashes otherwise and report "
    "the error to the jax-team.")
# NOTE(review): this help text says the default will be switched to True,
# but the default read from JAX_CPP_PMAP already is True; confirm wording.
flags.DEFINE_bool(
    "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
    "A flag enabling the C++ jax.pmap fast path. Until the default "
    "is switched to True, the feature is not supported and possibly broken "
    "(e.g. it may use unreleased code from jaxlib).")
flags.DEFINE_bool(
    "experimental_cpp_pjit", bool_env("JAX_CPP_PJIT", False),
    "A flag enabling the C++ pjit fast path. Until the default "
    "is switched to True, the feature is not supported and possibly broken "
    "(e.g. it may use unreleased code from jaxlib).")
def _nan_check_posthook(fun, args, kwargs, output):
  """Post-hook for the C++ jit/pmap fast path that checks outputs for NaN/Inf.

  Collects the device buffers behind each output leaf and hands them to
  ``dispatch.check_special``. If a ``FloatingPointError`` results, the call
  is replayed through ``fun._cache_miss`` (the de-optimized Python path),
  which normally does not return.

  Args:
    fun: The C++-jitted/pmapped callable that produced ``output``.
    args: Positional arguments of that call.
    kwargs: Keyword arguments of that call.
    output: The pytree returned by the call.
  """
  def buffers_of(leaf):
    # Single-buffer arrays expose ``device_buffer``; multi-buffer ones
    # expose ``device_buffers``. Anything else contributes nothing.
    if hasattr(leaf, "device_buffer"):
      return [leaf.device_buffer]
    if hasattr(leaf, "device_buffers"):
      return list(leaf.device_buffers)
    return []

  all_buffers = [buf for leaf in tree_leaves(output) for buf in buffers_of(leaf)]
  try:
    dispatch.check_special(xla.xla_call_p, all_buffers)
  except FloatingPointError:
    # The check can only fail when one of the debug options is active.
    assert config.jax_debug_nans or config.jax_debug_infs
    print("Invalid nan value encountered in the output of a C++-jit/pmap "
          "function. Calling the de-optimized version.")
    fun._cache_miss(*args, **kwargs)[0]  # probably won't return
def _update_debug_special_global(_):
  """Install or remove the NaN/Inf post-hook in the global C++ jit state.

  The argument (the new option value) is ignored; both global options are
  re-read through ``config._read``.
  """
  checking = config._read("jax_debug_nans") or config._read("jax_debug_infs")
  jax_jit.global_state().post_hook = _nan_check_posthook if checking else None
def _update_debug_special_thread_local(_):
  """Install or remove the NaN/Inf post-hook in the thread-local jit state.

  The argument is ignored; the thread-local debug options are read from
  ``config_thread_local_state`` and treated as False when unset.
  """
  checking = (getattr(config_thread_local_state, "jax_debug_nans", False) or
              getattr(config_thread_local_state, "jax_debug_infs", False))
  jax_jit.thread_local_state().post_hook = (
      _nan_check_posthook if checking else None)
# Re-evaluate whether the C++ jit post-hook is needed whenever the NaN or Inf
# debug option changes. By their bodies, the first callback updates the global
# jit state and the second the thread-local one (presumably matching how
# _add_hooks uses its two arguments).
config_debug_nans._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)
config_debug_infs._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)
# Re-export of dtypes.float0 at this module level.
float0 = dtypes.float0
def _check_callable(fun):
# In Python 3.10+, the only thing stopping us from supporting staticmethods
# is that we can't take weak references to them, which the C++ JIT requires.
if isinstance(fun, staticmethod):
raise TypeError(f"staticmethod arguments are not supported, got {fun}")
if not callable(fun):
raise TypeError(f"Expected a callable value, got {fun}")
if _isgeneratorfunction(fun):
raise TypeError(f"Expected a function, got a generator function: {fun}")
def _isgeneratorfunction(fun):
# TODO 3.9+: remove
# re-implemented here because of https://bugs.python.org/issue33261
while inspect.ismethod(fun):
fun = fun.__func__
while isinstance(fun, functools.partial):
fun = fun.func
return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
_POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
def _infer_argnums_and_argnames(
sig: inspect.Signature,
argnums: Union[int, Iterable[int], None],
argnames: Union[str, Iterable[str], None],
) -> Tuple[Tuple[int, ...], Tuple[str, ...]]:
"""Infer missing argnums and argnames for a function with inspect."""
if argnums is None and argnames is None:
return (), ()
if argnums is not None and argnames is not None:
argnums = _ensure_index_tuple(argnums)
argnames = _ensure_str_tuple(argnames)
return argnums, argnames
parameters = sig.parameters
if argnums is None:
assert argnames is not None
argnames = _ensure_str_tuple(argnames)
argnums = tuple(
i for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames
)
else:
argnums = _ensure_index_tuple(argnums)
argnames = tuple(
k for i, (k, param) in enumerate(parameters.items())
if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums
)
return argnums, argnames
def jit(
    fun: Callable,
    *,
    static_argnums: Union[int, Iterable[int], None] = None,
    static_argnames: Union[str, Iterable[str], None] = None,
    device: Optional[xc.Device] = None,
    backend: Optional[str] = None,
    donate_argnums: Union[int, Iterable[int]] = (),
    inline: bool = False,
    keep_unused: bool = False,
    abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
  """Sets up ``fun`` for just-in-time compilation with XLA.

  Args:
    fun: Function to be jitted. ``fun`` should be a pure function, as
      side-effects may only be executed once.

      The arguments and return value of ``fun`` should be arrays,
      scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
      Positional arguments indicated by ``static_argnums`` can be anything at
      all, provided they are hashable and have an equality operation defined.
      Static arguments are included as part of a compilation cache key, which is
      why hash and equality operators must be defined.

      JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
      so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
      objects will already satisfy this requirement.
    static_argnums: An optional int or collection of ints that specify which
      positional arguments to treat as static (compile-time constant).
      Operations that only depend on static arguments will be constant-folded in
      Python (during tracing), and so the corresponding argument values can be
      any Python object.

      Static arguments should be hashable, meaning both ``__hash__`` and
      ``__eq__`` are implemented, and immutable. Calling the jitted function
      with different values for these constants will trigger recompilation.
      Arguments that are not arrays or containers thereof must be marked as
      static.

      If neither ``static_argnums`` nor ``static_argnames`` is provided, no
      arguments are treated as static. If ``static_argnums`` is not provided but
      ``static_argnames`` is, or vice versa, JAX uses
      :code:`inspect.signature(fun)` to find any positional arguments that
      correspond to ``static_argnames``
      (or vice versa). If both ``static_argnums`` and ``static_argnames`` are
      provided, ``inspect.signature`` is not used, and only actual
      parameters listed in either ``static_argnums`` or ``static_argnames`` will
      be treated as static.
    static_argnames: An optional string or collection of strings specifying
      which named arguments to treat as static (compile-time constant). See the
      comment on ``static_argnums`` for details. If not
      provided but ``static_argnums`` is set, the default is based on calling
      ``inspect.signature(fun)`` to find corresponding named arguments.
    device: This is an experimental feature and the API is likely to change.
      Optional, the Device the jitted function will run on. (Available devices
      can be retrieved via :py:func:`jax.devices`.) The default is inherited
      from XLA's DeviceAssignment logic and is usually to use
      ``jax.devices()[0]``.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.
    donate_argnums: Specify which positional argument buffers are "donated" to
      the computation. It is safe to donate argument buffers if you no longer
      need them once the computation has finished. In some cases XLA can make
      use of donated buffers to reduce the amount of memory needed to perform a
      computation, for example recycling one of your input buffers to store a
      result. You should not reuse buffers that you donate to a computation;
      JAX will raise an error if you try to. By default, no argument buffers
      are donated.

      Note that ``donate_argnums`` only works for positional arguments;
      keyword arguments will not be donated.

      For more details on buffer donation see the
      [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
    inline: Specify whether this function should be inlined into enclosing
      jaxprs (rather than being represented as an application of the xla_call
      primitive with its own subjaxpr). Default False.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
    abstracted_axes: Experimental. An optional pytree, broadcast as a prefix
      over the positional arguments, describing which array axes are treated
      as abstract (dynamically sized). Only allowed when the
      ``jax_dynamic_shapes`` config option is enabled. Calling the jitted
      function with keyword arguments while ``abstracted_axes`` is set is not
      supported.

  Returns:
    A wrapped version of ``fun``, set up for just-in-time compilation.

  Raises:
    ValueError: If ``abstracted_axes`` is given while ``jax_dynamic_shapes``
      is disabled.
    TypeError: If ``fun`` is not a supported callable: a ``staticmethod``
      object, a non-callable, or a generator function.

  Examples:
    In the following example, ``selu`` can be compiled into a single fused kernel
    by XLA:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def selu(x, alpha=1.67, lmbda=1.05):
    ...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
    >>>
    >>> key = jax.random.PRNGKey(0)
    >>> x = jax.random.normal(key, (10,))
    >>> print(selu(x))  # doctest: +SKIP
    [-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
     -0.85743 -0.78232  0.76827  0.59566 ]

    To pass arguments such as ``static_argnames`` when decorating a function, a common
    pattern is to use :func:`functools.partial`:

    >>> from functools import partial
    >>> import jax.numpy as jnp
    >>>
    >>> @partial(jax.jit, static_argnames=['n'])
    ... def g(x, n):
    ...   for i in range(n):
    ...     x = x ** 2
    ...   return x
    >>>
    >>> g(jnp.arange(4), 3)
    DeviceArray([   0,    1,  256, 6561], dtype=int32)
  """
  if abstracted_axes and not config.jax_dynamic_shapes:
    raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
  # The C++ path is only taken when dynamic shapes are off. abstracted_axes
  # requires dynamic shapes, so it is only forwarded to the Python path.
  if FLAGS.experimental_cpp_jit and not config.jax_dynamic_shapes:
    return _jit(True, fun, static_argnums, static_argnames, device, backend,
                donate_argnums, inline, keep_unused)
  return _jit(False, fun, static_argnums, static_argnames, device, backend,
              donate_argnums, inline, keep_unused, abstracted_axes)
def _jit(
    use_cpp_jit: bool,
    fun: Callable,
    static_argnums: Union[int, Iterable[int], None] = None,
    static_argnames: Union[str, Iterable[str], None] = None,
    device: Optional[xc.Device] = None,
    backend: Optional[str] = None,
    donate_argnums: Union[int, Iterable[int]] = (),
    inline: bool = False,
    keep_unused: bool = False,
    abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
  """Shared front end for the C++ and Python ``jit`` implementations.

  Checks ``fun``, normalizes the static/donated argument specifications
  (inferring and validating them against ``fun``'s signature when one is
  available), then delegates to ``_cpp_jit`` or ``_python_jit``.
  """
  _check_callable(fun)
  donate_argnums = _ensure_index_tuple(donate_argnums)

  try:
    sig = inspect.signature(fun)
  except ValueError:
    # Some built-in functions don't support signature
    # (https://github.com/python/cpython/issues/73485); then we only coerce
    # the specifications to tuples and skip inference and validation.
    sig = None

  if sig is None:
    static_argnums = (() if static_argnums is None
                      else _ensure_index_tuple(static_argnums))
    static_argnames = (() if static_argnames is None
                       else _ensure_str_tuple(static_argnames))
  else:
    # Fill in whichever of argnums/argnames is missing, as the jit docstring
    # describes, then validate everything against the signature.
    static_argnums, static_argnames = _infer_argnums_and_argnames(
        sig, static_argnums, static_argnames)
    validate_argnums(sig, static_argnums, "static_argnums")
    validate_argnums(sig, donate_argnums, "donate_argnums")
    validate_argnames(sig, static_argnames, "static_argnames")

  # Shift donated indices to compensate for static argnums absorbing args.
  donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

  shared_kwargs = dict(
      static_argnums=static_argnums, static_argnames=static_argnames,
      device=device, backend=backend, donate_argnums=donate_argnums,
      inline=inline, keep_unused=keep_unused)
  if use_cpp_jit:
    return _cpp_jit(fun, **shared_kwargs)
  return _python_jit(fun, abstracted_axes=abstracted_axes, **shared_kwargs)
def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
args, kwargs):
# Validate donate_argnums
if max(donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
f = lu.wrap_init(fun)
f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
f, kwargs = argnames_partial_except(f, static_argnames, kwargs)
args_flat, in_tree = tree_flatten((args, kwargs))
if donate_argnums:
donated_invars = donation_vector(donate_argnums, args, kwargs)
else:
donated_invars = (False,) * len(args_flat)
return f, in_tree, args_flat, donated_invars
# Type of the ``abstracted_axes`` argument: a pytree that _flat_axes_specs
# broadcasts as a prefix over the positional arguments. Left as ``Any``.
PytreeOfAbstractedAxesSpec = Any
def _python_jit(
    fun: Callable,
    *,
    static_argnums: Tuple[int, ...],
    static_argnames: Tuple[str, ...],
    device: Optional[xc.Device],
    backend: Optional[str],
    donate_argnums: Tuple[int, ...],
    inline: bool,
    keep_unused: bool,
    abstracted_axes: Optional[PytreeOfAbstractedAxesSpec],
) -> stages.Wrapped:
  """Pure-Python implementation of ``jit``.

  Each call of the returned wrapper binds static args, flattens the rest,
  checks them, and dispatches through ``xla.xla_call``. It runs ``fun``
  directly when ``jax_disable_jit`` is set. The wrapper also exposes
  ``lower`` and ``clear_cache``.
  """

  @wraps(fun)
  @api_boundary
  def f_jitted(*args, **kwargs):
    if config.jax_disable_jit:
      return fun(*args, **kwargs)
    bound_fun, tree_in, flat_args, donated = _prepare_jit(
        fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
    traced_fun, out_tree = flatten_fun(bound_fun, tree_in)
    for leaf in flat_args:
      _check_arg(leaf)
    if jax.config.jax_dynamic_shapes:
      specs = None
      if abstracted_axes is not None:
        specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
      traced_fun = lu.annotate(
          traced_fun, pe.infer_lambda_input_type(specs, flat_args))
    flat_out = xla.xla_call(
        traced_fun, *flat_args,
        device=device, backend=backend, name=traced_fun.__name__,
        donated_invars=donated, inline=inline,
        keep_unused=keep_unused)
    return tree_unflatten(out_tree(), flat_out)

  f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
                              backend, donate_argnums, inline, keep_unused,
                              abstracted_axes)

  def clear_cache():
    # Drop every cached xla_callable entry compiled for ``fun``.
    dispatch.xla_callable.evict_function(fun)

  f_jitted.clear_cache = clear_cache
  return f_jitted
def _flat_axes_specs(abstracted_axes, *args, **kwargs
) -> List[pe.AbstractedAxesSpec]:
if kwargs: raise NotImplementedError
def ax_leaf(l):
return (isinstance(l, dict) and all_leaves(l.values()) or
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
return broadcast_prefix(abstracted_axes, args, ax_leaf)
class _BackendAndDeviceInfo(NamedTuple):
  """Device information handed to the C++ jit (see get_device_info in _cpp_jit)."""
  # The explicit ``device`` if one was given, otherwise the first device of
  # the (possibly explicitly selected) backend's default device assignment.
  default_device: xc.Device
  # True iff the user passed an explicit ``device`` or ``backend``.
  committed_to_device: bool
class _FastpathData(NamedTuple):
  """Data returned from ``cache_miss`` so the C++ jit can take its fast path.

  Produced by ``_jax_array_use_fast_path`` and ``_device_array_use_fast_path``.
  """
  # Compiled executable of the most recent xla_callable entry.
  xla_executable: xla.XlaExecutable
  # Output pytree structure of the jitted function.
  out_pytree_def: Any
  # None on the jax.Array path; on the DeviceArray path, the device taken
  # from the result handlers' args.
  sticky_device: Optional[xc.Device]
  # Abstract values of the flat outputs.
  avals: Iterable[Any]
  # One entry per output; both producers fill it with None.
  lazy_exprs: Iterable[Any]
  # One bool per flat input: whether its index is in the executable's
  # kept_var_idx.
  kept_var_bitvec: Iterable[bool]
  # Per-output shardings (jax.Array path; empty on the DeviceArray path).
  shardings: Iterable[Any]
  # Per-output ``_committed`` flags (jax.Array path; empty otherwise).
  committed: Iterable[bool]
# Compiled-function cache shared by every wrapper built in _cpp_jit (passed to
# jax_jit.jit as ``cache=``).
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
def _cpp_jit_clear_cache(self):
  """``clear_cache`` method for C++-jitted functions.

  First clears the object's own C++-side cache via ``self._clear_cache()``,
  then evicts the Python ``xla_callable`` entries for the wrapped ``self._fun``.
  """
  self._clear_cache()
  wrapped_fun = self._fun
  dispatch.xla_callable.evict_function(wrapped_fun)
def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
use_fastpath = (
xc._version >= 92 and
# This is if we have already executed this code-path (most-recent entry
# has been reset to None). Thus, we do not support the fast-path.
execute is not None and
type(execute) is pxla.ExecuteReplicated and
# No effects in computation
not execute.ordered_effects and
not execute.has_unordered_effects and
not execute.has_host_callbacks and
all(isinstance(x, xc.Array) for x in out_flat) and
# Not supported: dynamic shapes
not jax.config.jax_dynamic_shapes
# TODO(chky): Check sharding is SingleDeviceSharding
)
if use_fastpath:
sticky_device = None
lazy_exprs = [None] * len(out_flat)
kept_var_bitvec = [i in execute.kept_var_idx for i in range(len(args_flat))]
avals = [out.aval for out in out_flat]
shardings = [out.sharding for out in out_flat]
committed = [out._committed for out in out_flat]
return _FastpathData(execute.xla_executable, out_pytree_def, sticky_device,
avals, lazy_exprs, kept_var_bitvec, shardings,
committed)
return None
def _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
  """Return ``_FastpathData`` for the DeviceArray (non-jax.Array) path, or None.

  ``execute`` is the most recent ``dispatch.xla_callable`` entry (annotated as
  a ``functools.partial`` at the call site). The fast path is used only when
  ``execute.func`` is ``dispatch._execute_compiled``, the effect and
  host-callback entries of ``execute.args`` are all falsy, every output is a
  DeviceArray, dynamic shapes are off, and ``execute.args[4]`` (the result
  handlers) is a ``dispatch.SimpleResultHandler``.
  """
  # TODO(sharadmv): Clean up usage of `execute.args`
  use_fastpath = (
      # This is if we have already executed this code-path (most-recent entry
      # has been reset to None). Thus, we do not support the fast-path.
      execute is not None and
      execute.func is dispatch._execute_compiled and  # not trivial, not pmap
      # No effects in computation
      not execute.args[5] and not execute.args[6] and
      # Has no host callbacks
      not execute.args[8] and
      # Not supported: ShardedDeviceArray
      all(device_array.type_is_device_array(x) for x in out_flat) and
      # Not supported: dynamic shapes
      not jax.config.jax_dynamic_shapes
      and type(execute.args[4]) is dispatch.SimpleResultHandler)
  ### If we can use the fastpath, we return required info to the caller.
  if use_fastpath:
    # Positional layout of execute.args used here: [1] executable,
    # [4] result handlers, [7] kept_var_idx.
    (_, xla_executable, _, _, result_handlers, _, _, kept_var_idx,
     _) = execute.args  # pytype: disable=attribute-error
    sticky_device = None
    avals = []
    lazy_exprs = [None] * len(result_handlers)
    # NOTE(review): args[4] is type-checked above as a SimpleResultHandler
    # but iterated here as a sequence of handlers; presumably
    # SimpleResultHandler is iterable. Confirm.
    # The last handler's device wins as sticky_device.
    for result_handler in result_handlers:
      aval, sticky_device = result_handler.args
      avals.append(aval)
    assert len(avals) == len(out_flat)
    kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
    # Shardings/committed flags are only tracked on the jax.Array path.
    shardings = []
    committed = []
    return _FastpathData(xla_executable, out_pytree_def, sticky_device, avals,
                         lazy_exprs, kept_var_bitvec, shardings, committed)
  return None
def _cpp_jit(
    fun: Callable,
    *,
    static_argnums: Tuple[int, ...],
    static_argnames: Tuple[str, ...],
    device: Optional[xc.Device],
    backend: Optional[str],
    donate_argnums: Tuple[int, ...],
    inline: bool,
    keep_unused: bool,
) -> stages.Wrapped:
  """Build a ``jit`` wrapper whose dispatch is driven by ``jax_jit.jit`` (C++).

  The argument specifications are expected to be normalized already (as
  ``_jit`` does). The returned object carries a ``lower`` method and a
  ``_fun`` attribute referencing ``fun``. Its type is given a ``clear_cache``
  method.

  Raises:
    ValueError: If both ``device`` and ``backend`` are specified.
  """
  # An implementation of `jit` that tries to do as much as possible in C++.
  # The goal of this function is to speed up the time it takes to process the
  # arguments, find the correct C++ executable, start the transfer of arguments
  # and schedule the computation.
  # As long as it does not support all features of the Python implementation
  # the C++ code will fallback to `_python_jit` when it faces some unsupported
  # feature.
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     f"got device={device} and backend={backend}.")

  # Python fallback handed to jax_jit.jit; presumably called by the C++ side
  # when it cannot serve a call itself. Returns ``(out, fastpath_data)``, where
  # ``fastpath_data`` is None when the C++ fast path can't be used.
  @api_boundary
  def cache_miss(*args, **kwargs):
    ### This first part is basically the same code as in _python_jit.
    # An alternative would be for cache_miss to accept from C++ the arguments
    # (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
    # work/code that is redundant between C++ and Python. We can try that later.

    closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
        fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
    for arg in args_flat:
      _check_arg(arg)
    flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
    if jax.config.jax_dynamic_shapes:
      in_type = pe.infer_lambda_input_type(None, args_flat)
      flat_fun = lu.annotate(flat_fun, in_type)
    out_flat = xla.xla_call(
        flat_fun, *args_flat,
        device=device, backend=backend, name=flat_fun.__name__,
        donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
    out_pytree_def = out_tree()
    out = tree_unflatten(out_pytree_def, out_flat)

    ### Decide whether we can support the C++ fast path
    # High level note: The Python tracing mechanism is complex; in particular
    # to know whether `jax.jit(f)(x)` will execute or trace, it's not enough to
    # inspect the argument x, we actually do need to execute it and look at the
    # outputs that could be tracers (if f is capturing `Tracer` by closure).
    # This must be read right after xla_call above so it refers to the entry
    # used by this call.
    execute: Optional[functools.partial] = (
        dispatch.xla_callable.most_recent_entry())
    fastpath_data = None

    # TODO(sharadmv): Enable fast path for effectful jaxprs
    if jax.config.jax_array:
      fastpath_data = _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat)
    else:
      fastpath_data = _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat)

    return out, fastpath_data

  def get_device_info():
    """Backends do not exist before __main__ is being executed."""
    # Resolved lazily, at call time, rather than when the wrapper is built.
    committed_to_device = device is not None or backend is not None

    if device is not None:
      default_device = device
    else:
      backend_ = xb.get_backend(backend)
      default_device = backend_.get_default_device_assignment(1)[0]

    return _BackendAndDeviceInfo(default_device, committed_to_device)

  jitted_f_kwargs = {}
  jitted_f_kwargs["has_explicit_device"] = (
      device is not None or backend is not None)
  cpp_jitted_f = jax_jit.jit(
      fun,
      cache_miss,
      get_device_info,
      static_argnums=static_argnums,
      static_argnames=static_argnames,
      donate_argnums=donate_argnums,
      cache=_cpp_jit_cache,
      **jitted_f_kwargs)  # type: ignore

  f_jitted = wraps(fun)(cpp_jitted_f)
  f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
                              backend, donate_argnums, inline, keep_unused,
                              None)
  # _cpp_jit_clear_cache reads ``self._fun``, so this must be set.
  f_jitted._fun = fun
  # NOTE: assigned on the *type*, so every object of that type shares this
  # clear_cache method.
  type(f_jitted).clear_cache = _cpp_jit_clear_cache

  return f_jitted
def _jit_lower(fun, static_argnums, static_argnames, device, backend,
               donate_argnums, inline, keep_unused: bool,
               abstracted_axes: Optional[PytreeOfAbstractedAxesSpec]):
  """Make the ``lower`` method attached to jitted functions.

  If ``jit`` returned a class instance, this would naturally be a method
  with ``fun`` as ``self`` and the other arguments stored as attributes.
  """

  def arg_spec(x):
    from jax.experimental.sharding import PmapSharding
    # Like xla.arg_spec but duck-types on x.shape and x.dtype. Under dynamic
    # shapes no aval is computed here.
    aval = None if jax.config.jax_dynamic_shapes else shaped_abstractify(x)
    if not jax.config.jax_array:
      return aval, getattr(x, '_device', None)
    # Only committed, non-Pmap shardings are reported.
    if not hasattr(x, 'sharding') or isinstance(x.sharding, PmapSharding):
      return aval, None
    return aval, (x.sharding if x._committed else None)

  @api_boundary
  def lower(*args, **kwargs) -> stages.Lowered:
    """Lower this function for the given arguments.

    A lowered function is staged out of Python and translated to a
    compiler's input language, possibly in a backend-dependent
    manner. It is ready for compilation but not yet compiled.

    Returns:
      A ``Lowered`` instance representing the lowering.
    """
    closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
        fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
    flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
    specs = map(arg_spec, args_flat)

    if jax.config.jax_dynamic_shapes:
      axes_specs = None
      if abstracted_axes is not None:
        axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
      in_type = pe.infer_lambda_input_type(axes_specs, args_flat)
      flat_fun = lu.annotate(flat_fun, in_type)
      in_avals = [aval for aval, explicit in in_type if explicit]
    elif abstracted_axes:
      raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
    else:
      in_avals, _ = unzip2(specs)

    # Both lowering entry points share a signature; pick by array mode.
    lowering_fn = (dispatch.sharded_lowering if jax.config.jax_array
                   else dispatch.lower_xla_callable)
    computation = lowering_fn(
        flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
        keep_unused, *specs)
    return stages.Lowered.from_flat_info(
        computation, in_tree, in_avals, donate_argnums, out_tree())

  return lower
@contextmanager
def disable_jit(disable: bool = True):
  """Context manager that disables :py:func:`jit` behavior under its dynamic context.

  For debugging it is useful to have a mechanism that disables :py:func:`jit`
  everywhere in a dynamic context. Note that this not only disables explicit
  uses of `jit` by the user, but will also remove any implicit JIT compilation
  used by the JAX library: this includes implicit JIT computation of `body` and
  `cond` functions passed to higher-level primitives like :func:`scan` and
  :func:`while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
  and any other case where `jit` is used within an API's implementation.

  Values that have a data dependence on the arguments to a jitted function are
  traced and abstracted. For example, an abstract value may be a
  :py:class:`ShapedArray` instance, representing the set of all possible arrays
  with a given shape and dtype, but not representing one concrete array with
  specific values. You might notice those if you use a benign side-effecting
  operation in a jitted function, like a print:

  >>> import jax
  >>>
  >>> @jax.jit
  ... def f(x):
  ...   y = x * 2
  ...   print("Value of y is", y)
  ...   return y + 3
  ...
  >>> print(f(jax.numpy.array([1, 2, 3])))
  Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
  [5 7 9]

  Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
  which represents an array with a fixed shape and type but an arbitrary value.
  The value of ``y`` is also traced. If we want to see a concrete value while
  debugging, and avoid the tracer too, we can use the :py:func:`disable_jit`
  context manager:

  >>> import jax
  >>>
  >>> with jax.disable_jit():
  ...   print(f(jax.numpy.array([1, 2, 3])))
  ...
  Value of y is [2 4 6]
  [5 7 9]

  Args:
    disable: Whether :py:func:`jit` should be disabled inside the context.
      Defaults to ``True``.
  """
  # Delegates to the jax_disable_jit config context manager.
  with _disable_jit(disable):
    yield
def xla_computation(fun: Callable,
                    static_argnums: Union[int, Iterable[int]] = (),
                    axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
                    in_parts=None, out_parts=None,
                    backend: Optional[str] = None,
                    tuple_args: Optional[bool] = False,
                    instantiate_const_outputs: Optional[bool] = None,
                    return_shape: bool = False,
                    donate_argnums: Union[int, Iterable[int]] = ()) -> Callable:
  """Creates a function that produces its XLA computation given example args.

  Args:
    fun: Function from which to form XLA computations.
    static_argnums: See the :py:func:`jax.jit` docstring.
    axis_env: Optional, a sequence of pairs where the first element is an axis
      name and the second element is a positive integer representing the size of
      the mapped axis with that name. This parameter is useful when lowering
      functions that involve parallel communication collectives, and it
      specifies the axis name/size environment that would be set up by
      applications of :py:func:`jax.pmap`. See the examples below.
    in_parts: Optional, how each argument to ``fun`` should be partitioned or
      replicated. This is used to specify partitioned XLA computations, see
      ``sharded_jit`` for more info.
    out_parts: Optional, how each output of ``fun`` should be partitioned or
      replicated. This is used to specify partitioned XLA computations, see
      ``sharded_jit`` for more info.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``. When omitted, the platform of the default backend is used.
    tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting
      XLA computation will have a single tuple argument that is unpacked into
      the specified function arguments. If ``None``, whether to tuple is
      decided by ``dispatch.should_tuple_args`` from the number of flattened
      arguments and the target platform, since some platforms have limits on
      argument arity.
    instantiate_const_outputs: Deprecated argument, does nothing.
    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
      wrapped function returns a pair where the first element is the XLA
      computation and the second element is a pytree with the same structure as
      the output of ``fun`` and where the leaves are objects with ``shape``,
      ``dtype``, and ``named_shape`` attributes representing the corresponding
      types of the output leaves.
    donate_argnums: Specify which arguments are "donated" to the computation.
      It is safe to donate arguments if you no longer need them once the
      computation has finished. In some cases XLA can make use of donated
      buffers to reduce the amount of memory needed to perform a computation,
      for example recycling one of your input buffers to store a result. You
      should not reuse buffers that you donate to a computation, JAX will raise
      an error if you try to.

  Returns:
    A wrapped version of ``fun`` that when applied to example arguments returns
    a built XLA Computation (see xla_client.py), from which representations of
    the unoptimized XLA HLO computation can be extracted using methods like
    ``as_hlo_text``, ``as_serialized_hlo_module_proto``, and
    ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
    wrapped function returns a pair where the first element is the XLA
    Computation and the second element is a pytree representing the structure,
    shapes, dtypes, and named shapes of the output of ``fun``.

    The wrapped function raises ``ValueError`` if it is called with fewer
    positional arguments than the largest index in ``static_argnums`` /
    ``donate_argnums`` requires.

  Concrete example arguments are not always necessary. For those arguments not
  indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
  attributes is acceptable (excepting namedtuples, which are treated as Python
  containers).

  For example:

  >>> import jax
  >>>
  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> c = jax.xla_computation(f)(3.)
  >>> print(c.as_hlo_text())  # doctest: +SKIP
  HloModule xla_computation_f.6
  <BLANKLINE>
  ENTRY xla_computation_f.6 {
    constant.2 = pred[] constant(false)
    parameter.1 = f32[] parameter(0)
    cosine.3 = f32[] cosine(parameter.1)
    sine.4 = f32[] sine(cosine.3)
    ROOT tuple.5 = (f32[]) tuple(sine.4)
  }
  <BLANKLINE>
  <BLANKLINE>


  Alternatively, the assignment to ``c`` above could be written:

  >>> import types
  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
  >>> c = jax.xla_computation(f)(scalar)


  Here's an example that involves a parallel collective and axis name:

  >>> def f(x): return x - jax.lax.psum(x, 'i')
  >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
  >>> print(c.as_hlo_text())  # doctest: +SKIP
  HloModule jaxpr_computation.9
  primitive_computation.3 {
    parameter.4 = s32[] parameter(0)
    parameter.5 = s32[] parameter(1)
    ROOT add.6 = s32[] add(parameter.4, parameter.5)
  }
  ENTRY jaxpr_computation.9 {
    tuple.1 = () tuple()
    parameter.2 = s32[] parameter(0)
    all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
    ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
  }
  <BLANKLINE>
  <BLANKLINE>

  Notice the ``replica_groups`` that were generated. Here's an example that
  generates more interesting ``replica_groups``:

  >>> from jax import lax
  >>> def g(x):
  ...   rowsum = lax.psum(x, 'i')
  ...   colsum = lax.psum(x, 'j')
  ...   allsum = lax.psum(x, ('i', 'j'))
  ...   return rowsum, colsum, allsum
  ...
  >>> axis_env = [('i', 4), ('j', 2)]
  >>> c = xla_computation(g, axis_env=axis_env)(5.)
  >>> print(c.as_hlo_text())  # doctest: +SKIP
  HloModule jaxpr_computation__1.19
  [removed uninteresting text here]
  ENTRY jaxpr_computation__1.19 {
    tuple.1 = () tuple()
    parameter.2 = f32[] parameter(0)
    all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
    all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
    all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
    ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
  }
  """
  del instantiate_const_outputs  # Unused

  _check_callable(fun)
  static_argnums = _ensure_index_tuple(static_argnums)
  donate_argnums = _ensure_index_tuple(donate_argnums)
  # Donation indices are re-expressed relative to the dynamic (non-static)
  # arguments, which are the only ones that reach the traced function.
  donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

  fun_name = getattr(fun, "__name__", "unknown")

  # Resolve the platform name once; `backend` itself is a string (or None).
  platform = backend if backend is not None else xb.get_backend().platform

  def make_axis_env(nreps):
    # Without a user axis_env, only the replica count from the jaxpr is used;
    # otherwise it is multiplied by the product of all named axis sizes.
    if axis_env is None:
      return xla.AxisEnv(nreps, (), ())
    else:
      nreps = nreps * prod(size for name, size in axis_env)
      names, sizes = unzip2(axis_env)
      return xla.AxisEnv(nreps, names, sizes)

  @wraps(fun)
  @api_boundary
  def computation_maker(*args, **kwargs):
    if max(static_argnums + donate_argnums, default=-1) >= len(args):
      raise ValueError(f"jitted function has static_argnums={static_argnums},"
                       f" donate_argnums={donate_argnums} but "
                       f"was called with only {len(args)} positional arguments.")

    f = lu.wrap_init(fun)
    f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
    args_flat, in_tree = tree_flatten((dyn_args, kwargs))
    if donate_argnums:
      donated_invars = donation_vector(donate_argnums, dyn_args, kwargs)
    else:
      donated_invars = (False,) * len(args_flat)

    if in_parts is None:
      in_parts_flat = None
    else:
      # in_parts describes positional args only, hence children()[0].
      in_parts_flat = tuple(flatten_axes(
          "xla_computation in_parts", in_tree.children()[0], in_parts))

    jaxtree_fun, out_tree = flatten_fun(f, in_tree)
    avals = map(shaped_abstractify, args_flat)
    # Bind the user-provided named axes while tracing so collectives such as
    # psum can resolve their axis names.
    with ExitStack() as stack:
      for axis_name, size in axis_env or []:
        stack.enter_context(core.extend_axis_env(axis_name, size, None))
      jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
    jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
    axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))

    if out_parts is None:
      out_parts_flat = None
    else:
      out_parts_flat = tuple(flatten_axes(
          "xla_computation out_parts", out_tree(), out_parts))

    unordered_effects = [eff for eff in jaxpr.effects
                         if eff not in core.ordered_effects]
    ordered_effects = [eff for eff in jaxpr.effects
                       if eff in core.ordered_effects]
    lowering_result = mlir.lower_jaxpr_to_module(
        f"xla_computation_{fun_name}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects=unordered_effects,
        ordered_effects=ordered_effects,
        backend_or_name=backend,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env_),
        name_stack=new_name_stack(wrap_name(fun_name, "xla_computation")),
        donated_args=donated_invars,
        arg_shardings=(None if in_parts_flat is None else map(
            xla.sharding_to_proto, in_parts_flat)),
        result_shardings=(None if out_parts_flat is None else map(
            xla.sharding_to_proto, out_parts_flat)))

    if tuple_args is not None:
      should_tuple = tuple_args
    else:
      # Previously this branch discarded the result (leaving `should_tuple`
      # unbound) and read `backend.platform` from a string; use the platform
      # name resolved above instead.
      should_tuple = dispatch.should_tuple_args(len(avals), platform)

    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(lowering_result.module),
        use_tuple_args=should_tuple,
        return_tuple=True)

    out_shapes_flat = [
        ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
    out_shape = tree_unflatten(out_tree(), out_shapes_flat)
    for out_aval in out_avals:
      if not isinstance(out_aval, xla.ShapedArray):
        raise RuntimeError("As we want to propagate the weak_type, we need "
                           "to get a ShapedArray, otherwise this "
                           "information is lost")

    if return_shape:
      return built, out_shape
    else:
      return built

  return computation_maker
def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
         has_aux: bool = False, holomorphic: bool = False,
         allow_int: bool = False,
         reduce_axes: Sequence[AxisName] = ()) -> Callable:
  """Creates a function that evaluates the gradient of ``fun``.

  Args:
    fun: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, or standard Python containers.
      Argument arrays in the positions specified by ``argnums`` must be of
      inexact (i.e., floating-point or complex) type. It
      should return a scalar (which includes arrays with shape ``()`` but not
      arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.
    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
      ``fun`` implicitly broadcasts a value over that axis, the backward pass
      will perform a ``psum`` of the corresponding gradient. Otherwise, the
      gradient will be per-example over named axes. For example, if ``'batch'``
      is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
      function that computes the total gradient while ``grad(f)`` will create
      one that computes the per-example gradient.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the gradient
    of ``fun``. If ``argnums`` is an integer then the gradient has the same
    shape and type as the positional argument indicated by that integer. If
    argnums is a tuple of integers, the gradient is a tuple of values with the
    same shapes and types as the corresponding arguments. If ``has_aux`` is True
    then a pair of (gradient, auxiliary_data) is returned.

  For example:

  >>> import jax
  >>>
  >>> grad_tanh = jax.grad(jax.numpy.tanh)
  >>> print(grad_tanh(0.2))
  0.961043
  """
  # All the real work (validation, VJP, dtype checks) lives in value_and_grad;
  # grad only drops the primal value from its result.
  vg_fun = value_and_grad(fun, argnums, has_aux=has_aux,
                          holomorphic=holomorphic,
                          allow_int=allow_int,
                          reduce_axes=reduce_axes)

  docstr = ("Gradient of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "gradient, which has the same shape as the arguments at "
            "positions {argnums}.")

  if has_aux:
    @wraps(fun, docstr=docstr, argnums=argnums)
    @api_boundary
    def grad_f_aux(*args, **kwargs):
      (_primal, aux), gradient = vg_fun(*args, **kwargs)
      return gradient, aux
    return grad_f_aux

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f(*args, **kwargs):
    _primal, gradient = vg_fun(*args, **kwargs)
    return gradient
  return grad_f
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
                   has_aux: bool = False, holomorphic: bool = False,
                   allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
  ) -> Callable[..., Tuple[Any, Any]]:
  """Create a function that evaluates both ``fun`` and the gradient of ``fun``.

  Args:
    fun: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, or standard Python containers. It
      should return a scalar (which includes arrays with shape ``()`` but not
      arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.
    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
      ``fun`` implicitly broadcasts a value over that axis, the backward pass
      will perform a ``psum`` of the corresponding gradient. Otherwise, the
      gradient will be per-example over named axes. For example, if ``'batch'``
      is a named batch axis, ``value_and_grad(f, reduce_axes=('batch',))`` will
      create a function that computes the total gradient while
      ``value_and_grad(f)`` will create one that computes the per-example
      gradient.

  Returns:
    A function with the same arguments as ``fun`` that evaluates both ``fun``
    and the gradient of ``fun`` and returns them as a pair (a two-element
    tuple). If ``argnums`` is an integer then the gradient has the same shape
    and type as the positional argument indicated by that integer. If argnums is
    a sequence of integers, the gradient is a tuple of values with the same
    shapes and types as the corresponding arguments. If ``has_aux`` is True
    then a tuple of ((value, auxiliary_data), gradient) is returned.

    The returned function raises ``TypeError`` when it receives too few
    positional arguments for ``argnums``, when an output is not a scalar, or
    when input/output dtypes are unsupported for differentiation.
  """
  docstr = ("Value and gradient of {fun} with respect to positional "
            "argument(s) {argnums}. Takes the same arguments as {fun} but "
            "returns a two-element tuple where the first element is the value "
            "of {fun} and the second element is the gradient, which has the "
            "same shape as the arguments at positions {argnums}.")

  _check_callable(fun)
  argnums = core.concrete_or_error(_ensure_index, argnums)
  reduce_axes = _ensure_str_tuple(reduce_axes)

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def value_and_grad_f(*args, **kwargs):
    single_argnum = isinstance(argnums, int)
    max_argnum = argnums if single_argnum else max(argnums)
    if max_argnum >= len(args):
      raise TypeError(f"differentiating with respect to argnums={argnums} requires at least "
                      f"{max_argnum + 1} positional arguments to be passed by the caller, "
                      f"but got only {len(args)} positional arguments.")

    # Close over kwargs and split positional args into differentiated
    # (dynamic) and fixed ones.
    wrapped = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(wrapped, argnums, args,
                                          require_static_args_hashable=False)
    for leaf in tree_leaves(dyn_args):
      _check_input_dtype_grad(holomorphic, allow_int, leaf)

    if has_aux:
      ans, vjp_py, aux = _vjp(
          f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
    else:
      ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)

    _check_scalar(ans)
    tree_map(partial(_check_output_dtype_grad, holomorphic), ans)

    # Pull back a unit cotangent; the VJP yields one gradient per dyn arg.
    grads = vjp_py(lax_internal._one(ans))
    if single_argnum:
      grads = grads[0]
    primal = (ans, aux) if has_aux else ans
    return primal, grads

  return value_and_grad_f
def _check_scalar(x):
  """Raise ``TypeError`` unless ``x`` has a ``ShapedArray`` aval of shape ``()``.

  Used to validate the primal output of functions passed to ``grad``.
  """
  describe = "Gradient only defined for scalar-output functions. Output {}.".format
  try:
    aval = core.get_aval(x)
  except TypeError as e:
    raise TypeError(describe(f"was {x}")) from e
  # Guard clauses: wrong abstract-value kind first, then non-scalar shape.
  if not isinstance(aval, ShapedArray):
    raise TypeError(describe(f"had abstract value {aval}"))
  if aval.shape != ():
    raise TypeError(describe(f"had shape: {aval.shape}"))
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
  """Validate that input ``x`` has a dtype usable for reverse-mode derivatives.

  Args:
    name: Name of the calling transformation, used as the error-message prefix.
    holomorphic: If True, ``x`` must have a complex dtype.
    allow_int: If True, integer and boolean inputs are accepted.
    x: The input value to check.

  Raises:
    TypeError: if the dtype is opaque, not complex under ``holomorphic``, an
      integer/boolean dtype without ``allow_int``, or otherwise not inexact.
  """
  _check_arg(x)
  aval = core.get_aval(x)
  dt = aval.dtype
  if core.is_opaque_dtype(dt):
    raise TypeError(
      f"{name} with input element type {dt.name}")
  if holomorphic and not dtypes.issubdtype(dt, np.complexfloating):
    raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
                    f"but got {dt.name}.")
  is_int_like = (dtypes.issubdtype(dt, np.integer) or
                 dtypes.issubdtype(dt, np.bool_))
  if is_int_like:
    # Integer/bool inputs are only an error when the caller did not opt in.
    if not allow_int:
      raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
                      f"that is a sub-dtype of np.inexact), but got {dt.name}. "
                      "If you want to use Boolean- or integer-valued inputs, use vjp "
                      "or set allow_int to True.")
  elif not dtypes.issubdtype(dt, np.inexact):
    raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
                    f"sub-dtype of np.bool_ or np.number), but got {dt.name}.")
_check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
def _check_output_dtype_revderiv(name, holomorphic, x):
  """Validate that output ``x`` has a dtype usable for reverse-mode derivatives.

  Args:
    name: Name of the calling transformation, used as the error-message prefix.
    holomorphic: If True, ``x`` must be complex; otherwise it must be a real
      floating-point dtype.
    x: The output value to check.

  Raises:
    TypeError: if the dtype is opaque or does not satisfy the rule above.
  """
  aval = core.get_aval(x)
  dt = aval.dtype
  if core.is_opaque_dtype(dt):
    raise TypeError(
      f"{name} with output element type {dt.name}")
  is_complex = dtypes.issubdtype(dt, np.complexfloating)
  if holomorphic:
    if not is_complex:
      raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
                      f"but got {dt.name}.")
    return
  if is_complex:
    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
                    f"a sub-dtype of np.floating), but got {dt.name}. "
                    "For holomorphic differentiation, pass holomorphic=True. "
                    "For differentiation of non-holomorphic functions involving complex "
                    "outputs, use jax.vjp directly.")
  if not dtypes.issubdtype(dt, np.floating):
    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
                    f"a sub-dtype of np.floating), but got {dt.name}. "
                    "For differentiation of functions with integer outputs, use "
                    "jax.vjp directly.")
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")
def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
           has_aux: bool = False, holomorphic: bool = False) -> Callable:
  """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Jacobian of
    ``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
    then a pair of (jacobian, auxiliary_data) is returned.

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x):
  ...   return jnp.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
  ...
  >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
  [[ 1.       0.       0.     ]
   [ 0.       0.       5.     ]
   [ 0.      16.      -2.     ]
   [ 1.6209   0.       0.84147]]
  """
  _check_callable(fun)
  argnums = _ensure_index(argnums)

  def jacfun(*args, **kwargs):
    wrapped = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(wrapped, argnums, args,
                                          require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)

    # One JVP per standard-basis tangent, batched with vmap; the primal output
    # (and aux) is shared across the batch, so it is not mapped (None).
    basis = _std_basis(dyn_args)
    if has_aux:
      pushfwd = partial(_jvp, f_partial, dyn_args, has_aux=True)
      y, jac, aux = vmap(pushfwd, out_axes=(None, -1, None))(basis)
    else:
      pushfwd = partial(_jvp, f_partial, dyn_args)
      y, jac = vmap(pushfwd, out_axes=(None, -1))(basis)

    tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
    return (jac_tree, aux) if has_aux else jac_tree

  return jacfun
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
  """Validate that input ``x`` has a dtype supported by ``jacfwd``.

  With ``holomorphic`` the dtype must be complex; otherwise it must be a real
  floating-point dtype. Opaque dtypes are always rejected.

  Raises:
    TypeError: if the dtype is not supported.
  """
  _check_arg(x)
  dt = core.get_aval(x).dtype
  if core.is_opaque_dtype(dt):
    raise TypeError(
      f"jacfwd with input element type {dt.name}")
  if holomorphic:
    if dtypes.issubdtype(dt, np.complexfloating):
      return
    raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
                    f"dtype, but got {dt.name}.")
  if dtypes.issubdtype(dt, np.floating):
    return
  raise TypeError("jacfwd requires real-valued inputs (input dtype that is "
                  f"a sub-dtype of np.floating), but got {dt.name}. "
                  "For holomorphic differentiation, pass holomorphic=True. "
                  "For differentiation of non-holomorphic functions involving "
                  "complex inputs or integer inputs, use jax.jvp directly.")
def _check_output_dtype_jacfwd(holomorphic, x):
  """Under ``holomorphic``, require output ``x`` to have a complex dtype.

  The abstract value is always computed, even when ``holomorphic`` is False.

  Raises:
    TypeError: if ``holomorphic`` is True and the dtype is not complex.
  """
  dt = core.get_aval(x).dtype
  if holomorphic and not dtypes.issubdtype(dt, np.complexfloating):
    raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, "
                    f"but got {dt.name}.")
def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
           has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
  """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
      Any integer-like value accepted by ``operator.index`` (e.g. a NumPy
      integer scalar) may be used.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Jacobian of
    ``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
    then a pair of (jacobian, auxiliary_data) is returned.

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x):
  ...   return jnp.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
  ...
  >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
  [[ 1.       0.       0.     ]
   [ 0.       0.       5.     ]
   [ 0.      16.      -2.     ]
   [ 1.6209   0.       0.84147]]
  """
  _check_callable(fun)
  # Normalize argnums exactly like jacfwd/value_and_grad do. Without this,
  # integer-like scalars (e.g. np.int64) are neither `int` instances nor
  # iterable, which breaks argnums_partial and the isinstance checks below.
  argnums = _ensure_index(argnums)

  def jacfun(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
    if not has_aux:
      y, pullback = _vjp(f_partial, *dyn_args)
    else:
      y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
    # One VJP per standard-basis cotangent of the output, batched with vmap.
    jac = vmap(pullback)(_std_basis(y))
    jac = jac[0] if isinstance(argnums, int) else jac
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
    # Put the output structure outermost, matching jacfwd's result layout.
    jac_tree = tree_transpose(tree_structure(example_args), tree_structure(y), jac_tree)
    if not has_aux:
      return jac_tree
    else:
      return jac_tree, aux

  return jacfun
jacobian = jacrev
_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")
def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
            has_aux: bool = False, holomorphic: bool = False) -> Callable:
  """Hessian of ``fun`` as a dense array.

  Computed as forward-over-reverse: ``jacfwd(jacrev(fun))`` with the same
  ``argnums``, ``has_aux`` and ``holomorphic`` settings passed to both.

  Args:
    fun: Function whose Hessian is to be computed.  Its arguments at positions
      specified by ``argnums`` should be arrays, scalars, or standard Python
      containers thereof. It should return arrays, scalars, or standard Python
      containers thereof.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Hessian of
    ``fun``.

  >>> import jax
  >>>
  >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
  >>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
  [[   6.   -2.]
   [  -2. -480.]]

  :py:func:`hessian` is a generalization of the usual definition of the Hessian
  that supports nested Python containers (i.e. pytrees) as inputs and outputs.
  The tree structure of ``jax.hessian(fun)(x)`` is given by forming a tree
  product of the structure of ``fun(x)`` with a tree product of two copies of
  the structure of ``x``. A tree product of two tree structures is formed by
  replacing each leaf of the first tree with a copy of the second. For example:

  >>> import jax.numpy as jnp
  >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
  >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
  {'c': {'a': {'a': DeviceArray([[[ 2.,  0.], [ 0.,  0.]],
                                 [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
               'b': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                                 [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
         'b': {'a': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                                 [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
               'b': DeviceArray([[[0.      , 0.      ], [0.      , 0.      ]],
                                [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

  Thus each leaf in the tree structure of ``jax.hessian(fun)(x)`` corresponds to
  a leaf of ``fun(x)`` and a pair of leaves of ``x``. For each leaf in
  ``jax.hessian(fun)(x)``, if the corresponding array leaf of ``fun(x)`` has
  shape ``(out_1, out_2, ...)`` and the corresponding array leaves of ``x`` have
  shape ``(in_1_1, in_1_2, ...)`` and ``(in_2_1, in_2_2, ...)`` respectively,
  then the Hessian leaf has shape ``(out_1, out_2, ..., in_1_1, in_1_2, ...,
  in_2_1, in_2_2, ...)``. In other words, the Python tree structure represents
  the block structure of the Hessian, with blocks determined by the input and
  output pytrees.

  In particular, an array is produced (with no pytrees involved) when the
  function input ``x`` and output ``fun(x)`` are each a single array, as in the
  ``g`` example above. If ``fun(x)`` has shape ``(out1, out2, ...)`` and ``x``
  has shape ``(in1, in2, ...)`` then ``jax.hessian(fun)(x)`` has shape
  ``(out1, out2, ..., in1, in2, ..., in1, in2, ...)``. To flatten pytrees into
  1D vectors, consider using :py:func:`jax.flatten_util.ravel_pytree`.
  """
  return jacfwd(jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
                argnums, has_aux=has_aux, holomorphic=holomorphic)
def _std_basis(pytree):
  """Return an identity matrix unraveled into ``pytree``'s structure.

  The identity has one row/column per scalar element across all leaves, with
  the dtype given by ``dtypes.result_type`` of the leaves. It is split along
  axis 1 back into the leaves' shapes; ``example=None`` means each piece is
  cast like its own leaf.
  """
  leaves = tree_flatten(pytree)[0]
  total_size = sum(np.size(leaf) for leaf in leaves)
  basis_dtype = dtypes.result_type(*leaves)
  identity = jax.numpy.eye(total_size, dtype=basis_dtype)
  return _unravel_array_into_pytree(pytree, 1, None, identity)
def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
  """Split the last axis of a forward-mode Jacobian leaf into input leaves.

  Each piece is cast like ``output_pytree_leaf``.
  """
  return _unravel_array_into_pytree(pytree=input_pytree, axis=-1,
                                    example=output_pytree_leaf, arr=arr)
def _jacrev_unravel(output_pytree, input_pytree_leaf, arr):
  """Split the leading axis of a reverse-mode Jacobian leaf into output leaves.

  Each piece is cast like ``input_pytree_leaf``.
  """
  return _unravel_array_into_pytree(pytree=output_pytree, axis=0,
                                    example=input_pytree_leaf, arr=arr)
def _possible_downcast(x, example):
  """Convert ``x`` to the dtype and weak_type of ``example``.

  If ``x`` is complex but ``example`` is not, only the real part of ``x`` is
  kept before conversion.

  Args:
    x: Array to convert.
    example: Value whose dtype/weak_type ``x`` should take, or ``None``. With
      ``None``, no real-part truncation happens and ``None`` is forwarded as
      both dtype and weak_type, as the original code intended.

  Returns:
    The result of ``lax_internal._convert_element_type`` on ``x``.
  """
  if example is None:
    # Previously the complex check below called `_dtype(None)` before the
    # None-handling further down could take effect.
    return lax_internal._convert_element_type(x, None, None)
  example_dtype = _dtype(example)
  if (dtypes.issubdtype(x.dtype, np.complexfloating) and
      not dtypes.issubdtype(example_dtype, np.complexfloating)):
    x = x.real
  return lax_internal._convert_element_type(
      x, example_dtype, dtypes.is_weakly_typed(example))
def _unravel_array_into_pytree(pytree, axis, example, arr):
  """Unravel an array into a PyTree with a given structure.

  ``arr`` is split along ``axis`` into consecutive chunks whose sizes equal the
  element counts of ``pytree``'s leaves. Each chunk is reshaped so that the
  split axis is replaced by that leaf's shape, then passed through
  ``_possible_downcast``.

  Args:
    pytree: The pytree that provides the structure.
    axis: The parameter axis is either -1, 0, or 1. It controls the
      resulting shapes.
    example: If specified, cast the components to the matching dtype/weak_type,
      or else use the pytree leaf type if example is None.
    arr: The array to be unraveled.

  Returns:
    A pytree with ``pytree``'s structure holding the reshaped chunks.
  """
  leaves, treedef = tree_flatten(pytree)
  axis = axis % arr.ndim
  lead, trail = arr.shape[:axis], arr.shape[axis+1:]
  target_shapes = [lead + np.shape(leaf) + trail for leaf in leaves]
  # Split points are the running totals of all but the last leaf's size.
  boundaries = np.cumsum(map(np.size, leaves[:-1]))
  chunks = _split(arr, boundaries, axis)
  out_leaves = []
  for chunk, target_shape, leaf in zip(chunks, target_shapes, leaves):
    cast_like = example if example is not None else leaf
    out_leaves.append(_possible_downcast(np.reshape(chunk, target_shape), cast_like))
  return tree_unflatten(treedef, out_leaves)
def _split(x, indices, axis):
if isinstance(x, np.ndarray):
return np.split(x, indices, axis)
else:
return x.split(indices, axis)
def vmap(fun: F,
         in_axes: Union[int, Sequence[Any]] = 0,
         out_axes: Any = 0,
         axis_name: Optional[Hashable] = None,
         axis_size: Optional[int] = None,
         spmd_axis_name: Optional[Hashable] = None) -> F:
  """Vectorizing map. Creates a function which maps ``fun`` over argument axes.

  Args:
    fun: Function to be mapped over additional axes.
    in_axes: An integer, None, or (nested) standard Python container
      (tuple/list/dict) thereof specifying which input array axes to map over.

      If each positional argument to ``fun`` is an array, then ``in_axes`` can
      be an integer, a None, or a tuple of integers and Nones with length equal
      to the number of positional arguments to ``fun``. An integer or ``None``
      indicates which array axis to map over for all arguments (with ``None``
      indicating not to map any axis), and a tuple indicates which axis to map
      for each corresponding positional argument. Axis integers must be in the
      range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
      dimensions (axes) of the corresponding input array.

      If the positional arguments to ``fun`` are container (pytree) types, the
      corresponding element of ``in_axes`` can itself be a matching container,
      so that distinct array axes can be mapped for different container
      elements. ``in_axes`` must be a container tree prefix of the positional
      argument tuple passed to ``fun``. See this link for more detail:
      https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

      A top-level list is accepted and treated like the equivalent tuple.

      Either ``axis_size`` must be provided explicitly, or at least one
      positional argument must have ``in_axes`` not None. The sizes of the
      mapped input axes for all mapped positional arguments must all be equal.

      Arguments passed as keywords are always mapped over their leading axis
      (i.e. axis index 0).

      See below for examples.

    out_axes: An integer, None, or (nested) standard Python container
      (tuple/list/dict) thereof indicating where the mapped axis should appear
      in the output. All outputs with a mapped axis must have a non-None
      ``out_axes`` specification. Axis integers must be in the range ``[-ndim,
      ndim)`` for each output array, where ``ndim`` is the number of dimensions
      (axes) of the array returned by the :func:`vmap`-ed function, which is one
      more than the number of dimensions (axes) of the corresponding array
      returned by ``fun``.
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    axis_size: Optional, an integer indicating the size of the axis to be
      mapped. If not provided, the mapped axis size is inferred from arguments.
    spmd_axis_name: Optional, a hashable Python object forwarded unchanged to
      the underlying batching transformation (``batching.batch``) as its
      ``spmd_axis_name``.

  Returns:
    Batched/vectorized version of ``fun`` with arguments that correspond to
    those of ``fun``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``fun``, but
    with extra array axes at positions indicated by ``out_axes``.

  Raises:
    TypeError: if a leaf of ``in_axes`` or ``out_axes`` is neither an ``int``
      nor one of the types in ``batching.spec_types``.

  For example, we can implement a matrix-matrix product using a vector dot
  product:

  >>> import jax.numpy as jnp
  >>>
  >>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
  >>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
  >>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

  Here we use ``[a,b]`` to indicate an array with shape (a,b). Here are some
  variants:

  >>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
  >>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
  >>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)

  Here's an example of using container types in ``in_axes`` to specify which
  axes of the container elements to map over:

  >>> A, B, C, D = 2, 3, 4, 5
  >>> x = jnp.ones((A, B))
  >>> y = jnp.ones((B, C))
  >>> z = jnp.ones((C, D))
  >>> def foo(tree_arg):
  ...   x, (y, z) = tree_arg
  ...   return jnp.dot(x, jnp.dot(y, z))
  >>> tree = (x, (y, z))
  >>> print(foo(tree))
  [[12. 12. 12. 12. 12.]
   [12. 12. 12. 12. 12.]]
  >>> from jax import vmap
  >>> K = 6  # batch size
  >>> x = jnp.ones((K, A, B))  # batch axis in different locations
  >>> y = jnp.ones((B, K, C))
  >>> z = jnp.ones((C, D, K))
  >>> tree = (x, (y, z))
  >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
  >>> print(vfoo(tree).shape)
  (6, 2, 5)

  Here's another example using container types in ``in_axes``, this time a
  dictionary, to specify the elements of the container to map over:

  >>> dct = {'a': 0., 'b': jnp.arange(5.)}
  >>> x = 1.
  >>> def foo(dct, x):
  ...  return dct['a'] + dct['b'] + x
  >>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
  >>> print(out)
  [1. 2. 3. 4. 5.]

  The results of a vectorized function can be mapped or unmapped. For example,
  the function below returns a pair with the first element mapped and the second
  unmapped. Only for unmapped results can we specify ``out_axes`` to be ``None``
  (to keep it unmapped).

  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
  (DeviceArray([4., 5.], dtype=float32), 8.0)

  If the ``out_axes`` is specified for an unmapped result, the result is
  broadcast across the mapped axis:

  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
  (DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32, weak_type=True))

  If the ``out_axes`` is specified for a mapped result, the result is transposed
  accordingly.

  Finally, here's an example using ``axis_name`` together with collectives:

  >>> xs = jnp.arange(3. * 4.).reshape(3, 4)
  >>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
  [[12. 15. 18. 21.]
   [12. 15. 18. 21.]
   [12. 15. 18. 21.]]

  See the :py:func:`jax.pmap` docstring for more examples involving collectives.
  """
  _check_callable(fun)
  docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
            "but with additional array axes over which {fun} is mapped.")
  if fun.__doc__:
    docstr += "\n\nOriginal documentation:\n\n"
    # ``wraps`` fills the ``{fun}`` placeholders above with format-style
    # substitution, so literal braces in the user's docstring (e.g.
    # "returns {'a': x}") must be escaped or that substitution fails.
    docstr += fun.__doc__.replace("{", "{{").replace("}", "}}")

  axis_name = core.no_axis_name if axis_name is None else axis_name

  if isinstance(in_axes, list):
    # To be a tree prefix of the positional args tuple, in_axes can never be a
    # list: if in_axes is not a leaf, it must be a tuple of trees. However,
    # in cases like these users expect tuples and lists to be treated
    # essentially interchangeably, so we canonicalize lists to tuples here
    # rather than raising an error. https://github.com/google/jax/issues/2367
    in_axes = tuple(in_axes)

  if not all(type(l) is int or type(l) in batching.spec_types
             for l in tree_leaves(in_axes)):
    raise TypeError("vmap in_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {in_axes}.")
  if not all(type(l) is int or type(l) in batching.spec_types
             for l in tree_leaves(out_axes)):
    raise TypeError("vmap out_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {out_axes}.")

  @wraps(fun, docstr=docstr)
  @api_boundary
  def vmap_f(*args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
    f = lu.wrap_init(fun)
    flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
    # Keyword arguments are always mapped over axis 0.
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
    axis_size_ = (axis_size if axis_size is not None else
                  _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap",
                                    kws=True))
    out_flat = batching.batch(
        flat_fun, axis_name, axis_size_, in_axes_flat,
        # out_tree() is only known after tracing, hence the thunk.
        lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
        spmd_axis_name=spmd_axis_name
    ).call_wrapped(*args_flat)
    return tree_unflatten(out_tree(), out_flat)

  return vmap_f
def _mapped_axis_size(tree, vals, dims, name, *, kws=False):
  """Return the single size shared by all mapped axes of ``vals``.

  Args:
    tree: treedef describing ``vals``. When ``kws`` is True it is the treedef
      of an ``(args, kwargs)`` pair.
    vals: flat list of argument leaves.
    dims: per-leaf axis to map (aligned with ``vals``); ``None`` marks a leaf
      that is not mapped.
    name: transformation name used in error messages (e.g. ``"vmap"``).
    kws: whether ``tree`` is an ``(args, kwargs)`` pair. If so, error messages
      about inconsistent sizes mention only positional arguments when no
      keyword arguments were passed.

  Returns:
    The mapped axis size, when all mapped axes agree after
    ``core.dedup_referents``.

  Raises:
    ValueError: if ``vals`` is empty, if an axis is out of range for its
      argument, if nothing is mapped, or if the mapped sizes disagree.
  """
  if not vals:
    args, kwargs = tree_unflatten(tree, vals)
    raise ValueError(
        f"{name} wrapped function must be passed at least one argument "
        f"containing an array, got empty *args={args} and **kwargs={kwargs}"
    )

  def _get_axis_size(name: str, shape: Tuple[core.AxisSize, ...], axis: int
                     ) -> core.AxisSize:
    try:
      return shape[axis]
    except (IndexError, TypeError) as e:
      min_rank = -axis if axis < 0 else axis + 1
      raise ValueError(
          f"{name} was requested to map its argument along axis {axis}, "
          f"which implies that its rank should be at least {min_rank}, "
          f"but is only {len(shape)} (its shape is {shape})") from e

  unique_sizes = core.dedup_referents(
      _get_axis_size(name, np.shape(val), dim)
      for val, dim in zip(vals, dims) if dim is not None)
  if not unique_sizes:
    raise ValueError(f"{name} must have at least one non-None value in in_axes")
  if len(unique_sizes) == 1:
    return unique_sizes[0]

  template = f"{name} got inconsistent sizes for array axes to be mapped:\n" + "{}"

  def _size_tree(treedef):
    # Per-leaf mapped size (None for unmapped leaves), rebuilt into `treedef`.
    per_leaf = [val.shape[dim] if dim is not None else None
                for val, dim in zip(vals, dims)]
    return tree_unflatten(treedef, per_leaf)

  # The message is tailored by argument index when the positional arguments
  # form a flat tuple of arrays, and falls back to a tree of sizes otherwise.
  if kws:
    position_only_tree, kwargs_tree = treedef_children(tree)
    if not treedef_is_leaf(kwargs_tree):
      raise ValueError(template.format(
          f"the tree of axis sizes is:\n{_size_tree(tree)}")) from None
    # No keyword arguments: talk only about the positional ones.
    tree = position_only_tree

  # TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
  if tree != tree_flatten((0,) * tree.num_leaves)[1]:
    raise ValueError(template.format(
        f"the tree of axis sizes is:\n{_size_tree(tree)}")) from None

  per_arg = [f"arg {i} has shape {np.shape(val)} and axis {dim} is to be mapped"
             for i, (val, dim) in enumerate(zip(vals, dims))]
  indices_by_size = collections.defaultdict(list)
  for i, (val, dim) in enumerate(zip(vals, dims)):
    if dim is not None:
      indices_by_size[val.shape[dim]].append(i)
  per_size = []
  for size, idxs in indices_by_size.items():
    plural = len(idxs) > 1
    joined = ", ".join(str(i) for i in idxs)
    per_size.append(
        f"{'args' if plural else 'arg'} {joined} "
        f"{'have' if plural else 'has'} {'axes' if plural else 'an axis'} "
        f"to be mapped of size {size}")
  raise ValueError(template.format("\n".join(per_arg + ["so"] + per_size))) from None
def pmap(
    fun: Callable,
    axis_name: Optional[AxisName] = None,
    *,
    in_axes=0,
    out_axes=0,
    static_broadcasted_argnums: Union[int, Iterable[int]] = (),
    devices: Optional[Sequence[xc.Device]] = None,  # noqa: F811
    backend: Optional[str] = None,
    axis_size: Optional[int] = None,
    donate_argnums: Union[int, Iterable[int]] = (),
    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
  """Parallel map with support for collective operations.

  The purpose of :py:func:`pmap` is to express single-program multiple-data
  (SPMD) programs. Applying :py:func:`pmap` to a function will compile the
  function with XLA (similarly to :py:func:`jit`), then execute it in parallel
  on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it
  is comparable to :py:func:`vmap` because both transformations map a function
  over array axes, but where :py:func:`vmap` vectorizes functions by pushing the
  mapped axis down into primitive operations, :py:func:`pmap` instead replicates
  the function and executes each replica on its own XLA device in parallel.

  The mapped axis size must be less than or equal to the number of local XLA
  devices available, as returned by :py:func:`jax.local_device_count()` (unless
  ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
  product of the mapped axis sizes must be less than or equal to the number of
  XLA devices.

  .. note::
    :py:func:`pmap` compiles ``fun``, so while it can be combined with
    :py:func:`jit`, it's usually unnecessary.

  **Multi-process platforms:** On multi-process platforms such as TPU pods,
  :py:func:`pmap` is designed to be used in SPMD Python programs, where every
  process is running the same Python code such that all processes run the same
  pmapped function in the same order. Each process should still call the pmapped
  function with mapped axis size equal to the number of *local* devices (unless
  ``devices`` is specified, see below), and an array of the same leading axis
  size will be returned as usual. However, any collective operations in ``fun``
  will be computed over *all* participating devices, including those on other
  processes, via device-to-device communication. Conceptually, this can be
  thought of as running a pmap over a single array sharded across processes,
  where each process "sees" only its local shard of the input and output. The
  SPMD model requires that the same multi-process pmaps must be run in the same
  order on all devices, but they can be interspersed with arbitrary operations
  running in a single process.

  Args:
    fun: Function to be mapped over argument axes. Its arguments and return
      value should be arrays, scalars, or (nested) standard Python containers
      (tuple/list/dict) thereof. Positional arguments indicated by
      ``static_broadcasted_argnums`` can be anything at all, provided they are
      hashable and have an equality operation defined.
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    in_axes: A non-negative integer, None, or nested Python container thereof
      that specifies which axes of positional arguments to map over. Arguments
      passed as keywords are always mapped over their leading axis (i.e. axis
      index 0). See :py:func:`vmap` for details.
    out_axes: A non-negative integer, None, or nested Python container thereof
      indicating where the mapped axis should appear in the output. All outputs
      with a mapped axis must have a non-None ``out_axes`` specification
      (see :py:func:`vmap`).
    static_broadcasted_argnums: An int or collection of ints specifying which
      positional arguments to treat as static (compile-time constant).
      Operations that only depend on static arguments will be constant-folded.
      Calling the pmapped function with different values for these constants
      will trigger recompilation. If the pmapped function is called with fewer
      positional arguments than indicated by ``static_broadcasted_argnums``
      then an error is raised. Each of the static arguments will be
      broadcasted to all devices. Arguments that are not arrays or containers
      thereof must be marked as static. Defaults to ().

      Static arguments must be hashable, meaning both ``__hash__`` and
      ``__eq__`` are implemented, and should be immutable.

    devices: This is an experimental feature and the API is likely to change.
      Optional, a sequence of Devices to map over. (Available devices can be
      retrieved via jax.devices()). Must be given identically for each process
      in multi-process settings (and will therefore include devices across
      processes). If specified, the size of the mapped axis must be equal to
      the number of devices in the sequence local to the given process. Nested
      :py:func:`pmap` s with ``devices`` specified in either the inner or outer
      :py:func:`pmap` are not yet supported.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
    axis_size: Optional; the global size of the mapped axis, i.e. its aggregate
      size across all processes. When given, the aggregate size of the mapped
      axis must match this value.
    donate_argnums: Specify which positional argument buffers are "donated" to
      the computation. It is safe to donate argument buffers if you no longer
      need them once the computation has finished. In some cases XLA can make
      use of donated buffers to reduce the amount of memory needed to perform a
      computation, for example recycling one of your input buffers to store a
      result. You should not reuse buffers that you donate to a computation;
      JAX will raise an error if you try to.
      Note that donate_argnums only works for positional arguments, and keyword
      arguments will not be donated.

      For more details on buffer donation see the
      `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.

    global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
      the partitioned values span multiple processes. The global cross-process
      per-replica shape of each argument, i.e. does not include the leading
      pmapped dimension. Can be None for replicated arguments. This API is
      likely to change in the future.

  Returns:
    A parallelized version of ``fun`` with arguments that correspond to those of
    ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
    with output that has an additional leading array axis (with the same size).

  For example, assuming 8 XLA devices are available, :py:func:`pmap` can be used
  as a map along a leading array axis:

  >>> import jax.numpy as jnp
  >>>
  >>> out = pmap(lambda x: x ** 2)(jnp.arange(8))  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [ 0  1  4  9 16 25 36 49]

  When the leading dimension is smaller than the number of available devices,
  JAX will simply run on a subset of devices:

  >>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
  >>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
  >>> out = pmap(jnp.dot)(x, y)  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [[[    4.     9.]
    [   12.    29.]]
   [[  244.   345.]
    [  348.   493.]]
   [[ 1412.  1737.]
    [ 1740.  2141.]]]

  If your leading dimension is larger than the number of available devices, you
  will get an error:

  >>> pmap(lambda x: x ** 2)(jnp.arange(9))  # doctest: +SKIP
  ValueError: ... requires 9 replicas, but only 8 XLA devices are available

  As with :py:func:`vmap`, using ``None`` in ``in_axes`` indicates that an
  argument doesn't have an extra axis and should be broadcasted, rather than
  mapped, across the replicas:

  >>> x, y = jnp.arange(2.), 4.
  >>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  ([4., 5.], [8., 8.])

  Note that :py:func:`pmap` always returns values mapped over their leading axis,
  equivalent to using ``out_axes=0`` in :py:func:`vmap`.

  In addition to expressing pure maps, :py:func:`pmap` can also be used to express
  parallel single-program multiple-data (SPMD) programs that communicate via
  collective operations. For example:

  >>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
  >>> out = pmap(f, axis_name='i')(jnp.arange(4.))  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [ 0.          0.16666667  0.33333334  0.5       ]
  >>> print(out.sum())  # doctest: +SKIP
  1.0

  In this example, ``axis_name`` is a string, but it can be any Python object
  with ``__hash__`` and ``__eq__`` defined.

  The argument ``axis_name`` to :py:func:`pmap` names the mapped axis so that
  collective operations, like :func:`jax.lax.psum`, can refer to it. Axis names
  are important particularly in the case of nested :py:func:`pmap` functions,
  where collective operations can operate over distinct axes:

  >>> from functools import partial
  >>> import jax
  >>>
  >>> @partial(pmap, axis_name='rows')
  ... @partial(pmap, axis_name='cols')
  ... def normalize(x):
  ...   row_normed = x / jax.lax.psum(x, 'rows')
  ...   col_normed = x / jax.lax.psum(x, 'cols')
  ...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
  ...   return row_normed, col_normed, doubly_normed
  >>>
  >>> x = jnp.arange(8.).reshape((4, 2))
  >>> row_normed, col_normed, doubly_normed = normalize(x)  # doctest: +SKIP
  >>> print(row_normed.sum(0))  # doctest: +SKIP
  [ 1.  1.]
  >>> print(col_normed.sum(1))  # doctest: +SKIP
  [ 1.  1.  1.  1.]
  >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
  1.0

  On multi-process platforms, collective operations operate over all devices,
  including those on other processes. For example, assuming the following code
  runs on two processes with 4 XLA devices each:

  >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
  >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
  >>> out = pmap(f, axis_name='i')(data)  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [28 29 30 31] # on process 0
  [32 33 34 35] # on process 1

  Each process passes in a different length-4 array, corresponding to its 4
  local devices, and the psum operates over all 8 values. Conceptually, the two
  length-4 arrays can be thought of as a sharded length-8 array (in this example
  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped
  axis given name 'i'. The pmap call on each process then returns the
  corresponding length-4 output shard.

  The ``devices`` argument can be used to specify exactly which devices are used
  to run the parallel computation. For example, again assuming a single process
  with 8 devices, the following code defines two parallel computations, one
  which runs on the first six devices and one on the remaining two:

  >>> from functools import partial
  >>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
  ... def f1(x):
  ...   return x / jax.lax.psum(x, axis_name='i')
  >>>
  >>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
  ... def f2(x):
  ...   return jax.lax.psum(x ** 2, axis_name='i')
  >>>
  >>> print(f1(jnp.arange(6.)))  # doctest: +SKIP
  [0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
  >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
  [ 13.  13.]
  """
  # Both implementations share the same signature; the flag selects which one
  # builds the pmapped callable.
  impl = _cpp_pmap if FLAGS.experimental_cpp_pmap else _python_pmap
  return impl(
      fun,
      axis_name,
      in_axes=in_axes,
      out_axes=out_axes,
      static_broadcasted_argnums=static_broadcasted_argnums,
      devices=devices,
      backend=backend,
      axis_size=axis_size,
      donate_argnums=donate_argnums,
      global_arg_shapes=global_arg_shapes)
class PmapCallInfo(NamedTuple):
  """Per-call data computed by ``_prepare_pmap``.

  It is consumed both by the execution path (``_get_f_mapped``) and by the
  lowering path (``_pmap_lower``).
  """
  # ``fun`` wrapped by ``lu.wrap_init``, with static broadcasted arguments
  # partially applied (via ``argnums_partial``), then flattened.
  flat_fun: lu.WrappedFun
  # Treedef of the ``(dynamic_args, kwargs)`` pair that was flattened.
  in_tree: PyTreeDef
  # Output-treedef thunk returned by ``flatten_fun``. Callers in this file
  # invoke it (``out_tree()``) only after ``flat_fun`` has been run/lowered.
  out_tree: Callable[[], PyTreeDef]
  # Flattened leaves of ``(dynamic_args, kwargs)``.
  flat_args: Sequence[Any]
  # One flag per flat argument: from ``donation_vector``, or all False when
  # nothing is donated.
  donated_invars: Sequence[bool]
  # Per-flat-argument mapped axis (keyword arguments default to axis 0).
  in_axes_flat: Sequence[Optional[int]]
  # Mapped axis size as seen by this process, from ``_mapped_axis_size``.
  local_axis_size: int
  # Per-flat-argument global shapes (None where not given).
  global_arg_shapes_flat: Sequence[Optional[Tuple[int, ...]]]
  # Hashable thunk producing the flattened ``out_axes`` once the output tree
  # is known.
  out_axes_thunk: HashableFunction
  # ``None``, or the user-supplied devices converted to a tuple.
  devices: Optional[Sequence[xc.Device]]
def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
                  donate_tuple, global_arg_shapes, in_devices, args, kwargs):
  """Flatten one pmap call and compute everything needed to run or lower it.

  Args:
    fun: the user function being pmapped.
    in_axes: the user's ``in_axes`` specification (prefix of the positional
      argument tuple). A tuple is re-indexed to drop static arguments.
    out_axes: the user's ``out_axes`` specification.
    static_broadcasted_tuple: indices of positional arguments to treat as
      static. They are partially applied into ``fun``.
    donate_tuple: indices of donated arguments, as produced by
      ``_shared_code_pmap``.
    global_arg_shapes: optional per-argument global shapes. A tuple is
      re-indexed to drop static arguments.
    in_devices: optional sequence of devices; stored as a tuple.
    args: positional arguments of the call.
    kwargs: keyword arguments of the call (always mapped over axis 0).

  Returns:
    A ``PmapCallInfo`` describing the flattened call.

  Raises:
    ValueError: if a static broadcasted index is not a valid positional
      argument index for this call.
  """
  f = lu.wrap_init(fun)
  if static_broadcasted_tuple:
    if max(static_broadcasted_tuple) >= len(args):
      raise ValueError(
          f"pmapped function has static_broadcasted_argnums={static_broadcasted_tuple}"
          f" but was called with only {len(args)} positional "
          f"argument{'s' if len(args) != 1 else ''}. "
          "All static broadcasted arguments must be passed positionally.")
    dyn_argnums = [i for i in range(len(args))
                   if i not in static_broadcasted_tuple]
    f, dyn_args = argnums_partial(f, dyn_argnums, args)

    # Per-argument specs given as tuples must drop the static positions too.
    if isinstance(in_axes, tuple):
      dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
    else:
      dyn_in_axes = in_axes

    if isinstance(global_arg_shapes, tuple):
      dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums)
    else:
      dyn_global_arg_shapes = global_arg_shapes
  else:
    dyn_args, dyn_in_axes = args, in_axes
    dyn_global_arg_shapes = global_arg_shapes

  args, in_tree = tree_flatten((dyn_args, kwargs))

  if donate_tuple:
    donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
  else:
    donated_invars = (False,) * len(args)
  in_axes_flat = tuple(flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0)))
  global_arg_shapes_flat = tuple(flatten_axes(
      "pmap global_arg_shapes", in_tree, (dyn_global_arg_shapes, None),
      kws=True))
  local_axis_size = _mapped_axis_size(
      in_tree, args, in_axes_flat, "pmap", kws=True)

  flat_fun, out_tree = flatten_fun(f, in_tree)

  # NOTE(review): ``tree_flatten`` returns a ``(leaves, treedef)`` pair, so this
  # generator only inspects those two objects (neither is ever ``None``) and
  # the check cannot fire. ``None`` also does not appear among pytree leaves by
  # default. It is left unchanged because making it effective would newly
  # reject ``out_axes`` containing ``None``. Confirm the intended behavior first.
  if any(out_axis is None for out_axis in tree_flatten(out_axes)):
    raise NotImplementedError("None out_axes in pmap are not supported yet")
  # NOTE: We don't put out_tree() in the closure, because it's (1) non-hashable,
  # (2) depends deterministically on flat_fun (at least that's the assumption
  # that we make).
  if out_axes == 0:
    # TODO(apaszke,mattjj): flatten_axes assumes that the output pytree is
    # functorial (i.e. it can hold leaves of any type), but some user code
    # breaks this assumption. This is a stop-gap solution to keep the old
    # out_axes == 0 path working as we look for a better solution.
    out_axes_thunk = HashableFunction(
        lambda: (0,) * out_tree().num_leaves,
        closure=out_axes)
  else:
    # out_axes_thunk closes over the out_axes, they are flattened here to make
    # them hashable.
    out_axes_leaves, out_axes_treedef = tree_flatten(out_axes)
    out_axes_thunk = HashableFunction(
        lambda: tuple(flatten_axes("pmap out_axes", out_tree(),
                                   tree_unflatten(out_axes_treedef,
                                                  list(out_axes_leaves)))),
        closure=(tuple(out_axes_leaves), out_axes_treedef))

  return PmapCallInfo(flat_fun=flat_fun,
                      in_tree=in_tree,
                      out_tree=out_tree,
                      flat_args=args,
                      donated_invars=donated_invars,
                      in_axes_flat=in_axes_flat,
                      local_axis_size=local_axis_size,
                      global_arg_shapes_flat=global_arg_shapes_flat,
                      out_axes_thunk=out_axes_thunk,
                      devices=None if in_devices is None else tuple(in_devices))
def _get_f_mapped(
    *,
    fun: Callable,
    axis_name: Optional[AxisName],
    in_axes=0,
    out_axes=0,
    static_broadcasted_tuple: Tuple[int, ...],
    devices: Optional[Sequence[xc.Device]],  # noqa: F811
    backend: Optional[str],
    axis_size: Optional[int],
    donate_tuple: Tuple[int, ...],
    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]],
):
  """Build a callable that runs ``fun`` through ``pxla.xla_pmap``.

  The returned function flattens its arguments with ``_prepare_pmap`` and
  validates each flat argument with ``_check_arg``. It returns
  ``(out_tree_thunk, flat_outputs)``; the caller unflattens the outputs.
  """
  def pmap_f(*args, **kwargs):
    info = _prepare_pmap(
        fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
        global_arg_shapes, devices, args, kwargs)
    for flat_arg in info.flat_args:
      _check_arg(flat_arg)
    # ``axis_size`` is the user's global size; the per-process size comes
    # from the arguments themselves.
    flat_out = pxla.xla_pmap(
        info.flat_fun, *info.flat_args,
        backend=backend,
        axis_name=axis_name,
        axis_size=info.local_axis_size,
        global_axis_size=axis_size,
        devices=info.devices,
        in_axes=info.in_axes_flat,
        out_axes_thunk=info.out_axes_thunk,
        name=info.flat_fun.__name__,
        donated_invars=info.donated_invars,
        global_arg_shapes=info.global_arg_shapes_flat)
    return info.out_tree, flat_out

  return pmap_f
def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,
                      donate_argnums, in_axes, out_axes):
  """Validate and normalize the arguments shared by both pmap implementations.

  Args:
    fun: the function to pmap; must be callable.
    axis_name: user-supplied axis name, or ``None`` to use a fresh
      ``core._TempAxisName(fun)``.
    static_broadcasted_argnums: int or iterable of ints.
    donate_argnums: int or iterable of ints, rebased against the static
      argument indices.
    in_axes: axis spec whose leaves must all be ``int``.
    out_axes: axis spec whose leaves must all be ``int``.

  Returns:
    ``(axis_name, static_broadcasted_tuple, donate_tuple)``.

  Raises:
    TypeError: if any leaf of ``in_axes`` or ``out_axes`` is not an ``int``.
  """
  _check_callable(fun)
  if axis_name is None:
    axis_name = core._TempAxisName(fun)
  static_broadcasted_tuple = _ensure_index_tuple(static_broadcasted_argnums)
  donate_tuple = rebase_donate_argnums(
      _ensure_index_tuple(donate_argnums), static_broadcasted_tuple)

  # in_axes is validated before out_axes.
  for label, axes in (("in_axes", in_axes), ("out_axes", out_axes)):
    if any(type(leaf) is not int for leaf in tree_leaves(axes)):
      raise TypeError(f"pmap {label} must be an int, None, or (nested) container "
                      f"with those types as leaves, but got {axes}.")

  return axis_name, static_broadcasted_tuple, donate_tuple
def _python_pmap(
    fun: Callable,
    axis_name: Optional[AxisName] = None,
    *,
    in_axes=0,
    out_axes=0,
    static_broadcasted_argnums: Union[int, Iterable[int]] = (),
    devices: Optional[Sequence[xc.Device]] = None,  # noqa: F811
    backend: Optional[str] = None,
    axis_size: Optional[int] = None,
    donate_argnums: Union[int, Iterable[int]] = (),
    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
  """The Python only implementation.

  Each call goes through ``_get_f_mapped`` and unflattens the result. The
  returned function also carries a ``lower`` attribute built by
  ``_pmap_lower``.
  """
  axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
      fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
      out_axes)

  @wraps(fun)
  @api_boundary
  def pmap_f(*args, **kwargs):
    mapped = _get_f_mapped(
        fun=fun, axis_name=axis_name, in_axes=in_axes, out_axes=out_axes,
        static_broadcasted_tuple=static_broadcasted_tuple, devices=devices,
        backend=backend, axis_size=axis_size,
        global_arg_shapes=global_arg_shapes, donate_tuple=donate_tuple)
    out_tree_thunk, flat_out = mapped(*args, **kwargs)
    return tree_unflatten(out_tree_thunk(), flat_out)

  pmap_f.lower = _pmap_lower(
      fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
      backend, axis_size, global_arg_shapes, donate_tuple)
  return pmap_f
class _PmapFastpathData(NamedTuple):
  """Data returned by ``_cpp_pmap``'s ``cache_miss`` for the C++ fast path.

  ``cache_miss`` builds it only when the most recent ``ExecuteReplicated``
  entry is eligible. It is presumably read by ``pmap_lib.pmap`` (C++) so that
  later calls can skip Python dispatch; field names and order should therefore
  be treated as an interface. TODO confirm against pmap_lib.
  """
  version: int  # For forward and backward compatibility
  # Compiled executable taken from the ``ExecuteReplicated`` entry.
  xla_executable: xla.XlaExecutable
  # Input/output handlers of that ``ExecuteReplicated`` entry.
  in_handler: Any
  out_handler: Any
  # Output treedef (``out_tree()`` evaluated after the call).
  out_pytree_def: Any
  # Data needed to handle the inputs.
  input_sharding_specs: Sequence[pxla.ShardingSpec]
  input_devices: Sequence[xc.Device]
  input_indices: Sequence[pxla.Index]
  input_array_shardings: Sequence[Any]
  # Data needed to build the ShardedDeviceArray from C++.
  out_sharding_specs: Sequence[pxla.ShardingSpec]
  out_indices: Sequence[pxla.Index]
  out_avals: Sequence[Any]
  # Per-output ``sharding``/``_committed``; empty lists unless
  # ``jax.config.jax_array`` is enabled.
  out_array_shardings: Sequence[Any]
  out_committed: Sequence[Any]
def _cpp_pmap(
    fun: Callable,
    axis_name: Optional[AxisName] = None,
    *,
    in_axes=0,
    out_axes=0,
    static_broadcasted_argnums: Union[int, Iterable[int]] = (),
    devices: Optional[Sequence[xc.Device]] = None,  # noqa: F811
    backend: Optional[str] = None,
    axis_size: Optional[int] = None,
    donate_argnums: Union[int, Iterable[int]] = (),
    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
  """:py:func:`pmap` implementation that dispatches through ``pmap_lib.pmap``.

  ``pmap_lib.pmap`` is given ``cache_miss`` as its Python fallback.
  ``cache_miss`` runs the regular Python path and returns
  ``(output, fastpath_data)``. ``fastpath_data`` is a ``_PmapFastpathData``
  when the most recently compiled computation is eligible for the C++ fast
  path, and ``None`` otherwise. The returned function also carries a ``lower``
  attribute built by ``_pmap_lower``.
  """
  axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
      fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
      out_axes)
  del static_broadcasted_argnums, donate_argnums

  @api_boundary
  def cache_miss(*args, **kwargs):
    f_pmapped_ = _get_f_mapped(
        fun=fun,
        axis_name=axis_name,
        in_axes=in_axes,
        out_axes=out_axes,
        static_broadcasted_tuple=static_broadcasted_tuple,
        devices=devices,
        backend=backend,
        axis_size=axis_size,
        global_arg_shapes=global_arg_shapes,
        donate_tuple=donate_tuple)

    out_tree, out_flat = f_pmapped_(*args, **kwargs)
    out_pytree_def = out_tree()
    out = tree_unflatten(out_pytree_def, out_flat)

    ### Decide whether we can support the C++ fast path
    # ``execute`` is indexed below; presumably element 0 is the executable
    # wrapper for the call that just ran. TODO confirm against pxla.
    execute = pxla.parallel_callable.most_recent_entry()
    use_fastpath = (
        execute is not None and
        # We don't support JAX extension backends.
        isinstance(execute[0], pxla.ExecuteReplicated) and
        # TODO(sharadmv): Enable effects in replicated computation
        not execute[0].has_unordered_effects and
        not execute[0].has_host_callbacks and
        # No tracers in the outputs. Checking for ShardedDeviceArray should be
        # sufficient, but we use the more general `DeviceArray`.
        all(isinstance(x, device_array.DeviceArray) or
            (xc._version >= 94 and isinstance(x, xc.Array))
            for x in out_flat))
    if not use_fastpath:
      return out, None

    ### We can use the fastpath: return the info the C++ caller needs.
    execute_replicated = execute[0]
    out_handler = execute_replicated.out_handler
    in_handler = execute_replicated.in_handler
    out_indices = [tuple(s.devices_indices_map(a.shape).values())
                   for s, a in safe_zip(out_handler.out_shardings,
                                        out_handler.out_avals)]
    if jax.config.jax_array:
      out_array_shardings = [o.sharding for o in out_flat]
      out_committed = [o._committed for o in out_flat]
    else:
      out_array_shardings = []
      out_committed = []

    fastpath_data = _PmapFastpathData(
        version=1,
        xla_executable=execute_replicated.xla_executable,
        in_handler=in_handler,
        out_handler=out_handler,
        out_pytree_def=out_pytree_def,
        input_sharding_specs=[i.sharding_spec for i in in_handler.in_shardings],
        input_devices=in_handler.local_devices,
        input_indices=in_handler.input_indices,
        input_array_shardings=in_handler.in_shardings,
        out_sharding_specs=[s.sharding_spec for s in out_handler.out_shardings],
        out_indices=out_indices,
        out_avals=out_handler.out_avals,
        out_array_shardings=out_array_shardings,
        out_committed=out_committed,
    )
    return out, fastpath_data

  cpp_mapped_f = pmap_lib.pmap(
      fun, cache_miss, static_broadcasted_tuple,
      partial(pxla._shard_arg, mode=pxla.InputsHandlerMode.pmap))

  pmap_f = wraps(fun)(cpp_mapped_f)

  pmap_f.lower = _pmap_lower(
      fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
      backend, axis_size, global_arg_shapes, donate_tuple)

  return pmap_f
def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
                devices, backend, axis_size, global_arg_shapes, donate_tuple):  # noqa: F811
  """Make a ``lower`` method for pmapped functions.

  All configuration is captured in the returned closure. If the pmapped
  function were a class instance, this would naturally be a method with
  ``fun`` as ``self`` and the other arguments as attributes.
  """
  @api_boundary
  def lower(*args, **kwargs) -> stages.Lowered:
    """Lower a parallel-mapped form of this function for the given arguments.

    A parallel-mapped and lowered function is staged out of Python and
    translated to a compiler's input language, possibly in a
    backend-dependent manner. It is ready for compilation but is not yet
    compiled. It represents a function intended for SPMD execution on
    multiple devices.

    Returns:
      A ``Lowered`` instance representing the post-map lowering.
    """
    info = _prepare_pmap(
        fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
        global_arg_shapes, devices, args, kwargs)
    avals = [shaped_abstractify(flat_arg) for flat_arg in info.flat_args]
    lowering = pxla.lower_parallel_callable(
        info.flat_fun, backend, axis_name,
        axis_size=info.local_axis_size,
        global_axis_size=axis_size,
        devices=info.devices,
        name=info.flat_fun.__name__,
        in_axes=info.in_axes_flat,
        out_axes_thunk=info.out_axes_thunk,
        donated_invars=info.donated_invars,
        global_arg_shapes=info.global_arg_shapes_flat,
        avals=avals)
    # info.out_tree() is only valid once flat_fun has been traced above.
    return stages.Lowered.from_flat_info(
        lowering, info.in_tree, avals, donate_tuple, info.out_tree())

  return lower
def jvp(
    fun: Callable, primals, tangents, has_aux: bool = False
  ) -> Tuple[Any, ...]:
  """Computes a (forward-mode) Jacobian-vector product of ``fun``.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Jacobian of ``fun`` should be
      evaluated. Must be a tuple or a list of arguments whose length equals
      the number of positional parameters of ``fun``.
    tangents: The tangent vector for which the Jacobian-vector product should be
      evaluated. Must be a tuple or a list of tangents with the same tree
      structure and array shapes as ``primals``.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.

  Returns:
    If ``has_aux`` is ``False``, a ``(primals_out, tangents_out)`` pair, where
    ``primals_out`` is ``fun(*primals)`` and ``tangents_out`` is the
    Jacobian-vector product of ``fun`` evaluated at ``primals`` with
    ``tangents``. ``tangents_out`` has the same Python tree structure and shapes
    as ``primals_out``. If ``has_aux`` is ``True``, a
    ``(primals_out, tangents_out, aux)`` tuple, where ``aux`` is the auxiliary
    data returned by ``fun``.

  For example:

  >>> import jax
  >>>
  >>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
  >>> print(primals)
  0.09983342
  >>> print(tangents)
  0.19900084
  """
  _check_callable(fun)
  wrapped = lu.wrap_init(fun)
  return _jvp(wrapped, primals, tangents, has_aux=has_aux)
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
  """Variant of jvp() that takes an lu.WrappedFun.

  Validates the arguments, then runs ``ad.jvp`` on the flattened function.

  Args:
    fun: the wrapped function to differentiate.
    primals: tuple or list of primal arguments.
    tangents: tuple or list of tangents with the same pytree structure as
      ``primals``.
    has_aux: whether ``fun`` returns ``(output, aux)``.

  Returns:
    ``(primals_out, tangents_out)``, or ``(primals_out, tangents_out, aux)``
    when ``has_aux`` is true, each unflattened to its output tree.

  Raises:
    TypeError: if ``primals``/``tangents`` are not tuples or lists, if their
      tree structures differ, or if a tangent's dtype is not the tangent dtype
      of its primal.
    ValueError: if a primal and its tangent have different shapes.
  """
  if (not isinstance(primals, (tuple, list)) or
      not isinstance(tangents, (tuple, list))):
    raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
                    f"found {type(primals).__name__} and {type(tangents).__name__}.")

  ps_flat, tree_def = tree_flatten(primals)
  ts_flat, tree_def_2 = tree_flatten(tangents)
  if tree_def != tree_def_2:
    raise TypeError("primal and tangent arguments to jax.jvp must have the same tree "
                    f"structure; primals have tree structure {tree_def} whereas tangents have "
                    f"tree structure {tree_def_2}.")

  for p, t in safe_zip(ps_flat, ts_flat):
    primal_dtype = _dtype(p)
    tangent_dtype = _dtype(t)
    expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_dtype)
    if expected_tangent_dtype != tangent_dtype:
      # Note the trailing space after "float0." so the sentences do not run
      # together when the string pieces are concatenated.
      raise TypeError("primal and tangent arguments to jax.jvp do not match; "
                      "dtypes must be equal, or in case of int/bool primal dtype "
                      "the tangent dtype must be float0. "
                      f"Got primal dtype {primal_dtype} and so expected tangent dtype "
                      f"{expected_tangent_dtype}, but got "
                      f"tangent dtype {tangent_dtype} instead.")
    if np.shape(p) != np.shape(t):
      raise ValueError("jvp called with different primal and tangent shapes; "
                       f"Got primal shape {np.shape(p)} and tangent shape as {np.shape(t)}")

  if not has_aux:
    flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def)
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
    out_tree = out_tree()
    return (tree_unflatten(out_tree, out_primals),
            tree_unflatten(out_tree, out_tangents))
  else:
    flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, tree_def)
    jvp_fun, aux = ad.jvp(flat_fun, has_aux=True)
    out_primals, out_tangents = jvp_fun.call_wrapped(ps_flat, ts_flat)
    out_tree, aux_tree = out_aux_trees()
    return (tree_unflatten(out_tree, out_primals),
            tree_unflatten(out_tree, out_tangents),
            tree_unflatten(aux_tree, aux()))
def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
  """Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Jacobian of ``fun`` should be
      evaluated. Should be a tuple of arrays, scalars, or standard Python
      containers thereof. The length of the tuple is equal to the number of
      positional parameters of ``fun``.

  Returns:
    A pair whose first element is the value of ``f(*primals)`` and whose
    second element is a function that evaluates the (forward-mode)
    Jacobian-vector product of ``fun`` at ``primals`` without re-doing the
    linearization work.

  In terms of values computed, :py:func:`linearize` behaves much like a curried
  :py:func:`jvp`, where these two code blocks compute the same values::

    y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

    y, f_jvp = jax.linearize(f, x)
    out_tangent = f_jvp(in_tangent)

  The difference is that :py:func:`linearize` uses partial evaluation, so
  ``f`` is not re-linearized on each call to ``f_jvp``. In general that means
  the memory usage scales with the size of the computation, much like in
  reverse-mode. (Indeed, :py:func:`linearize` has a similar signature to
  :py:func:`vjp`!)

  This function is mainly useful if you want to apply ``f_jvp`` multiple times,
  i.e. to evaluate a pushforward for many different input tangent vectors at the
  same linearization point. Moreover if all the input tangent vectors are known
  at once, it can be more efficient to vectorize using :py:func:`vmap`, as in::

    pushfwd = partial(jvp, f, (x,))
    y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

  By using :py:func:`vmap` and :py:func:`jvp` together like this we avoid the
  stored-linearization memory cost that scales with the depth of the
  computation, which is incurred by both :py:func:`linearize` and
  :py:func:`vjp`.

  Here's a more complete example of using :py:func:`linearize`:

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
  ...
  >>> jax.jvp(f, (2.,), (3.,))
  (DeviceArray(3.26819, dtype=float32, weak_type=True), DeviceArray(-5.00753, dtype=float32, weak_type=True))
  >>> y, f_jvp = jax.linearize(f, 2.)
  >>> print(y)
  3.2681944
  >>> print(f_jvp(3.))
  -5.007528
  >>> print(f_jvp(4.))
  -6.676704
  """
  _check_callable(fun)
  wrapped = lu.wrap_init(fun)
  flat_primals, in_tree = tree_flatten((primals, {}))
  flat_fun, out_tree_thunk = flatten_fun(wrapped, in_tree)
  out_primals, out_pvals, jaxpr, consts = ad.linearize(flat_fun, *flat_primals)
  out_tree = out_tree_thunk()
  primals_out = tree_unflatten(out_tree, out_primals)
  primal_avals = [core.get_aval(p) for p in flat_primals]
  # Wrap in ``Partial`` so that the returned JVP function is itself a pytree.
  jvp_fn = Partial(
      partial(_lift_linearized, jaxpr, primal_avals, (in_tree, out_tree),
              out_pvals),
      consts)
  return primals_out, jvp_fn
def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
  """Apply a linearized tangent map to pytree-structured tangent arguments.

  Args:
    jaxpr: the linear jaxpr produced by ``ad.linearize``.
    primal_avals: avals of the flattened primals; each incoming tangent must be
      type-compatible with ``primal_aval.at_least_vspace()``.
    io_tree: ``(in_tree, out_tree)`` pair handed to ``apply_flat_fun``.
    out_pvals: one partial value per output. Known outputs are returned as
      their known value; unknown ones are filled, in order, from the results
      of evaluating ``jaxpr``.
    consts: constants for ``jaxpr``.
    *py_args: the tangent pytree(s) supplied by the caller.

  Raises:
    ValueError: if a tangent's aval is incompatible with its primal aval.
  """
  def fun(*tangents):
    tangent_avals = [core.get_aval(t) for t in tangents]
    for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
      if core.typecompat(primal_aval.at_least_vspace(), tangent_aval):
        continue
      raise ValueError("linearized function called on tangent values inconsistent with "
                       "the original primal values: "
                       f"got {tangent_aval} for primal aval {primal_aval}")
    computed = iter(eval_jaxpr(jaxpr, consts, *tangents))
    merged = []
    for pval in out_pvals:
      merged.append(pval.get_known() if pval.is_known() else next(computed))
    # Every computed tangent must have been consumed by an unknown output.
    assert next(computed, None) is None
    return merged

  return apply_flat_fun(fun, io_tree, *py_args)
def _vjp_pullback_wrapper(cotangent_dtypes, cotangent_shapes,
                          io_tree, fun, py_args):
  """Validate a user-supplied cotangent pytree, then run the flat VJP.

  Args:
    cotangent_dtypes: per-leaf expected cotangent dtypes, i.e. the tangent
      dtypes of the flattened primal outputs.
    cotangent_shapes: per-leaf shapes of the flattened primal outputs.
    io_tree: ``(in_tree_expected, out_tree)``. The first element is the
      primal-output treedef the cotangent must match. The second is the
      treedef used to unflatten the VJP's result.
    fun: flat VJP function, called as ``fun(*flat_cotangents)``.
    py_args: the cotangent pytree passed by the user.

  Returns:
    The result of ``fun`` unflattened with ``out_tree``.

  Raises:
    TypeError: if the cotangent tree structure or a leaf dtype does not match.
    ValueError: if a cotangent leaf's shape does not match.
  """
  in_tree_expected, out_tree = io_tree
  args, in_tree = tree_flatten(py_args)
  if in_tree != in_tree_expected:
    raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of "
                    f"primal output {in_tree_expected}.")
  for arg, ct_dtype, ct_shape in safe_zip(args, cotangent_dtypes, cotangent_shapes):
    arg_dtype = _dtype(arg)
    # ``ct_dtype`` is the expected dtype (derived from the primal output); the
    # user's cotangent is accepted if its own tangent dtype matches it.
    if core.primal_dtype_to_tangent_dtype(arg_dtype) != ct_dtype:
      raise TypeError(
          f"Type of cotangent input to vjp pullback function ({arg_dtype}) is not "
          f"the expected tangent type ({ct_dtype}) of corresponding primal output.")
    if np.shape(arg) != ct_shape:
      raise ValueError(
          f"Shape of cotangent input to vjp pullback function {np.shape(arg)} "
          "must be the same as the shape of corresponding primal output "
          f"{ct_shape}.")
  ans = fun(*args)
  return tree_unflatten(out_tree, ans)
@overload
def vjp(fun: Callable[..., T],
        *primals: Any,
        has_aux: Literal[False] = False,
        reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
  ...


@overload
def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
        has_aux: Literal[True],
        reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
  ...


def vjp(  # type: ignore
    fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
  ) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
  """Compute a (reverse-mode) vector-Jacobian product of ``fun``.

  :py:func:`grad` is implemented as a special case of :py:func:`vjp`.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: A sequence of primal values at which the Jacobian of ``fun``
      should be evaluated. The length of ``primals`` should be equal to the
      number of positional parameters to ``fun``. Each primal value should be
      an array, scalar, or (nested) standard Python container thereof.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
      ``fun`` implicitly broadcasts a value over that axis, the backward pass
      will perform a ``psum`` of the corresponding gradient. Otherwise, the
      VJP will be per-example over named axes. For example, if ``'batch'``
      is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
      create a VJP function that sums over the batch while ``vjp(f, *args)``
      will create a per-example VJP.

  Returns:
    If ``has_aux`` is ``False``, a ``(primals_out, vjpfun)`` pair, where
    ``primals_out`` is ``fun(*primals)`` and ``vjpfun`` maps a cotangent
    vector with the same shape as ``primals_out`` to a tuple of cotangent
    vectors with the same shape as ``primals``, representing the
    vector-Jacobian product of ``fun`` evaluated at ``primals``. If
    ``has_aux`` is ``True``, a ``(primals_out, vjpfun, aux)`` tuple, where
    ``aux`` is the auxiliary data returned by ``fun``.

  >>> import jax
  >>>
  >>> def f(x, y):
  ...   return jax.numpy.sin(x), jax.numpy.cos(y)
  ...
  >>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
  >>> xbar, ybar = f_vjp((-0.7, 0.3))
  >>> print(xbar)
  -0.61430776
  >>> print(ybar)
  -0.2524413
  """
  _check_callable(fun)
  axes = _ensure_str_tuple(reduce_axes)
  wrapped = lu.wrap_init(fun)
  return _vjp(wrapped, *primals, has_aux=has_aux, reduce_axes=axes)
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
  """Variant of vjp() that takes an lu.WrappedFun.

  Every flattened primal leaf is checked with ``_check_arg`` before
  ``ad.vjp`` is invoked. The returned pullback is wrapped in ``Partial`` so
  that it is a pytree.
  """
  primals_flat, in_tree = tree_flatten(primals)
  for leaf in primals_flat:
    _check_arg(leaf)

  if has_aux:
    flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
    out_primal, out_vjp, aux = ad.vjp(
        flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes)
    out_tree, aux_tree = out_aux_trees()
  else:
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(fun, in_tree)
    out_primal, out_vjp = ad.vjp(
        flat_fun, primals_flat, reduce_axes=reduce_axes)
    out_tree = out_tree_thunk()

  primals_out = tree_unflatten(out_tree, out_primal)
  ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(o)) for o in out_primal]
  ct_shapes = [np.shape(o) for o in out_primal]
  # A pytree pullback can be passed from the forward to the backward pass of
  # a custom VJP.
  pullback = Partial(
      partial(_vjp_pullback_wrapper, ct_dtypes, ct_shapes, (out_tree, in_tree)),
      out_vjp)

  if has_aux:
    return primals_out, pullback, tree_unflatten(aux_tree, aux)
  return primals_out, pullback
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
  """Transpose a function that is promised to be linear.

  For linear functions, this transformation is equivalent to ``vjp``, but
  avoids the overhead of computing the forward pass.

  The outputs of the transposed function will always have the exact same dtypes
  as ``primals``, even if some values are truncated (e.g., from complex to
  float, or from float64 to float32). To avoid truncation, use dtypes in
  ``primals`` that match the full range of desired outputs from the transposed
  function. The inputs and outputs of ``fun`` must either all be inexact
  (float or complex) or all be integer; mixed cases raise ``TypeError``.

  Args:
    fun: the linear function to be transposed.
    *primals: a positional argument tuple of arrays, scalars, or (nested)
      standard Python containers (tuples, lists, dicts, namedtuples, i.e.,
      pytrees) of those types used for evaluating the shape/dtype of
      ``fun(*primals)``. These arguments may be real scalars/ndarrays, but that
      is not required: only the ``shape`` and ``dtype`` attributes are accessed.
      See below for an example. (Note that the duck-typed objects cannot be
      namedtuples because those are treated as standard Python containers.)
    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
      ``fun`` implicitly broadcasts a value over that axis, the backward pass
      will perform a ``psum`` of the corresponding cotangent. Otherwise, the
      transposed function will be per-example over named axes. For example, if
      ``'batch'`` is a named batch axis, ``linear_transpose(f, *args,
      reduce_axes=('batch',))`` will create a transpose function that sums over
      the batch while ``linear_transpose(f, *args)`` will create a per-example
      transpose.

  Returns:
    A callable that calculates the transpose of ``fun``. Valid input into this
    function must have the same shape/dtypes/structure as the result of
    ``fun(*primals)``. Output will be a tuple, with the same
    shape/dtypes/structure as ``primals``.

  >>> import jax
  >>> import types
  >>>
  >>> f = lambda x, y: 0.5 * x - 0.5 * y
  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
  >>> f_transpose = jax.linear_transpose(f, scalar, scalar)
  >>> f_transpose(1.0)
  (DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
  """
  reduce_axes = _ensure_str_tuple(reduce_axes)
  primals_flat, in_tree = tree_flatten(primals)
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  in_avals = [shaped_abstractify(p) for p in primals_flat]
  in_dtypes = [dtypes.dtype(a) for a in in_avals]
  in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
  jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(
      flat_fun, in_pvals, instantiate=True)
  out_avals, _ = unzip2(out_pvals)
  out_dtypes = [dtypes.dtype(a) for a in out_avals]

  every_dtype = in_dtypes + out_dtypes
  def _all_of_kind(kind):
    return all(dtypes.issubdtype(d, kind) for d in every_dtype)
  if not (_all_of_kind(np.inexact) or _all_of_kind(np.integer)):
    raise TypeError("linear_transpose only supports [float or complex] -> "
                    "[float or complex], and integer -> integer functions, "
                    f"but got {in_dtypes} -> {out_dtypes}.")

  @api_boundary
  def transposed_fun(const, out_cotangent):
    out_cts, out_tree2 = tree_flatten(out_cotangent)
    expected_tree = out_tree()
    if expected_tree != out_tree2:
      raise TypeError("cotangent tree does not match function output, "
                      f"expected {expected_tree} but got {out_tree2}")
    if not all(map(core.typecheck, out_avals, out_cts)):
      raise TypeError("cotangent type does not match function output, "
                      f"expected {out_avals} but got {out_cts}")
    undefined = [ad.UndefinedPrimal(a) for a in in_avals]
    in_cts = ad.backward_pass(jaxpr, reduce_axes, True, const, undefined, out_cts)
    in_cts = map(ad.instantiate_zeros, in_cts)
    return tree_unflatten(in_tree, in_cts)

  # Returning a ``Partial`` makes the transposed function a pytree.
  return Partial(transposed_fun, const)
def make_jaxpr(fun: Callable,
               static_argnums: Union[int, Iterable[int]] = (),
               axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
               return_shape: bool = False,
               abstracted_axes: Optional[Any] = None,
               ) -> Callable[..., core.ClosedJaxpr]:
  """Creates a function that produces its jaxpr given example args.

  Args:
    fun: The function whose ``jaxpr`` is to be computed. Its positional
      arguments and return value should be arrays, scalars, or standard Python
      containers (tuple/list/dict) thereof.
    static_argnums: See the :py:func:`jax.jit` docstring.
    axis_env: Optional, a sequence of pairs where the first element is an axis
      name and the second element is a positive integer representing the size of
      the mapped axis with that name. This parameter is useful when lowering
      functions that involve parallel communication collectives, and it
      specifies the axis name/size environment that would be set up by
      applications of :py:func:`jax.pmap`.
    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
      wrapped function returns a pair where the first element is the
      ``ClosedJaxpr`` and the second element is a pytree with the same
      structure as the output of ``fun`` and where the leaves are objects with
      ``shape``, ``dtype``, and ``named_shape`` attributes representing the
      corresponding types of the output leaves.
    abstracted_axes: Optional, advanced. When ``None`` (the default), every
      flattened input is abstracted with ``shaped_abstractify``. Otherwise it
      is combined with the call's ``args``/``kwargs`` by ``_flat_axes_specs``,
      and the input types are inferred by ``pe.infer_lambda_input_type``.

  Returns:
    A wrapped version of ``fun`` that when applied to example arguments returns
    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
    argument ``return_shape`` is ``True``, then the returned function instead
    returns a pair where the first element is the ``ClosedJaxpr``
    representation of ``fun`` and the second element is a pytree representing
    the structure, shape, dtypes, and named shapes of the output of ``fun``.

  A ``jaxpr`` is JAX's intermediate representation for program traces. The
  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

  We do not describe the semantics of the ``jaxpr`` language in detail here, but
  instead give a few examples.

  >>> import jax
  >>>
  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> print(f(3.0))
  -0.83602
  >>> jax.make_jaxpr(f)(3.0)
  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
  >>> jax.make_jaxpr(jax.grad(f))(3.0)
  { lambda ; a:f32[]. let
      b:f32[] = cos a
      c:f32[] = sin a
      _:f32[] = sin b
      d:f32[] = cos b
      e:f32[] = mul 1.0 d
      f:f32[] = neg e
      g:f32[] = mul f c
    in (g,) }
  """
  _check_callable(fun)
  static_argnums = _ensure_index_tuple(static_argnums)

  def abstractify(args, kwargs):
    # Returns (input avals, input treedef, per-input keep flags).
    flat_args, in_tree = tree_flatten((args, kwargs))
    if abstracted_axes is None:
      return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
    else:
      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
      in_avals, keep_inputs = unzip2(in_type)
      return in_avals, in_tree, keep_inputs

  @wraps(fun)
  @api_boundary
  def make_jaxpr_f(*args, **kwargs):
    f = lu.wrap_init(fun)
    if static_argnums:
      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
      f, args = argnums_partial(f, dyn_argnums, args)
    in_avals, in_tree, keep_inputs = abstractify(args, kwargs)
    in_type = tuple(zip(in_avals, keep_inputs))
    f, out_tree = flatten_fun(f, in_tree)
    f = lu.annotate(f, in_type)
    with ExitStack() as stack:
      for axis_name, size in axis_env or []:
        stack.enter_context(core.extend_axis_env(axis_name, size, None))
      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
    closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
    if return_shape:
      out_avals, _ = unzip2(out_type)
      out_shapes_flat = [
          ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
      return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
    return closed_jaxpr

  # Name the wrapper after the wrapped function (previously this used
  # ``make_jaxpr.__name__``, always yielding "make_jaxpr(make_jaxpr)").
  # Callables such as ``functools.partial`` objects may lack ``__name__``; in
  # that case leave the name that ``wraps`` assigned.
  if hasattr(fun, "__name__"):
    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
  return make_jaxpr_f
def device_put(x, device: Optional[xc.Device] = None):
  """Transfers ``x`` to ``device``.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    device: The (optional) :py:class:`Device` to which ``x`` should be
      transferred. If given, then the result is committed to the device.

  If the ``device`` parameter is ``None``, then this operation behaves like the
  identity function if the operand is on any device already, otherwise it
  transfers the data to the default device, uncommitted.

  For more details on data placement see the
  :ref:`FAQ on data placement <faq-data-placement>`.

  This function is always asynchronous, i.e. returns immediately.

  Returns:
    A copy of ``x`` that resides on ``device``.
  """
  def _put_leaf(leaf):
    # Each pytree leaf goes through the ``device_put_p`` primitive.
    return dispatch.device_put_p.bind(leaf, device=device)

  with config_explicit_device_put_scope():
    return tree_map(_put_leaf, x)
def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  # noqa: F811
  """Transfer array shards to specified devices and form ShardedDeviceArray(s).

  Args:
    shards: A sequence of arrays, scalars, or (nested) standard Python
      containers thereof representing the shards to be stacked together to form
      the output. The length of ``shards`` must equal the length of ``devices``.
    devices: A sequence of :py:class:`Device` instances representing the devices
      to which corresponding shards in ``shards`` will be transferred.

  This function is always asynchronous, i.e. returns immediately.

  Returns:
    A ShardedDeviceArray or (nested) Python container thereof representing the
    elements of ``shards`` stacked together, with each shard backed by physical
    device memory specified by the corresponding entry in ``devices``.

  Raises:
    ValueError: if ``shards`` is not a sequence, if its length differs from
      that of ``devices``, or if the shards' shapes/dtypes are inconsistent.

  Examples:
    Passing a list of arrays for ``shards`` results in a sharded array
    containing a stacked version of the inputs:

    >>> import jax
    >>> devices = jax.local_devices()
    >>> x = [jax.numpy.ones(5) for device in devices]
    >>> y = jax.device_put_sharded(x, devices)
    >>> np.allclose(y, jax.numpy.stack(x))
    True

    Passing a list of nested container objects with arrays at the leaves for
    ``shards`` corresponds to stacking the shards at each leaf. This requires
    all entries in the list to have the same tree structure:

    >>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
    >>> y = jax.device_put_sharded(x, devices)
    >>> type(y)
    <class 'tuple'>
    >>> y0 = jax.device_put_sharded([a for a, b in x], devices)
    >>> y1 = jax.device_put_sharded([b for a, b in x], devices)
    >>> np.allclose(y[0], y0)
    True
    >>> np.allclose(y[1], y1)
    True

  See Also:
    - device_put
    - device_put_replicated
  """
  # TODO(jakevdp): provide a default for devices that considers both local
  # devices and pods
  if not isinstance(shards, Sequence):
    raise ValueError("device_put_sharded `shards` input must be a sequence; "
                     f"got {type(shards)}")
  if len(shards) != len(devices):
    raise ValueError(f"len(shards) = {len(shards)} must equal "
                     f"len(devices) = {len(devices)}.")

  def _device_put_sharded(*xs):
    avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs]
    # Report the first adjacent pair of shards whose avals differ.
    for prev, cur in zip(avals[:-1], avals[1:]):
      if prev != cur:
        raise ValueError("the shards passed to device_put_sharded must have "
                         f"consistent shape and dtype, but got {prev} and {cur}.")
    stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
    buffers = []
    for shard, dev in zip(xs, devices):
      buffers.extend(dispatch.device_put(shard, dev))
    if not config.jax_array:
      return pxla.make_sharded_device_array(stacked_aval, None, buffers)
    from jax.experimental import array, sharding
    sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
    return array.Array(
        stacked_aval,
        sharding.PmapSharding(np.array(devices), sharding_spec),
        buffers, committed=True, _skip_checks=True)

  with config_explicit_device_put_scope():
    return tree_map(_device_put_sharded, *shards)
def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
  """Transfer array(s) to each specified device and form ShardedDeviceArray(s).

  Args:
    x: an array, scalar, or (nested) standard Python container thereof
      representing the array to be replicated to form the output.
    devices: A sequence of :py:class:`Device` instances representing the devices
      to which ``x`` will be transferred.

  This function is always asynchronous, i.e. returns immediately.

  Returns:
    A ShardedDeviceArray or (nested) Python container thereof representing the
    value of ``x`` broadcasted along a new leading axis of size
    ``len(devices)``, with each slice along that new leading axis backed by
    memory on the device specified by the corresponding entry in ``devices``.

  Raises:
    ValueError: if ``devices`` is not a non-empty ``Sequence``.

  Examples:
    Passing an array:

    >>> import jax
    >>> devices = jax.local_devices()
    >>> x = jax.numpy.array([1., 2., 3.])
    >>> y = jax.device_put_replicated(x, devices)
    >>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
    True

  See Also:
    - device_put
    - device_put_sharded
  """
  if not isinstance(devices, Sequence) or not devices:
    # The backticks around the function name were previously unbalanced.
    raise ValueError("`devices` argument to `device_put_replicated` must be "
                     "a non-empty sequence.")

  def _device_put_replicated(x):
    # Aval of ``x`` with a new leading axis of size len(devices).
    aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
                              core.raise_to_shaped(core.get_aval(x)))
    assert (isinstance(aval, ShapedArray) and
            len(xla.aval_to_xla_shapes(aval)) == 1)
    # Transfer once to the first device, then copy device-to-device.
    buf, = dispatch.device_put(x, devices[0])
    rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
    if config.jax_array:
      from jax.experimental import array, sharding
      sharding_spec = pxla._create_pmap_sharding_spec(aval)
      return array.Array(
          aval, sharding.PmapSharding(np.array(devices), sharding_spec),
          [buf, *rest_bufs], committed=True, _skip_checks=True)
    else:
      return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])

  with config_explicit_device_put_scope():
    return tree_map(_device_put_replicated, x)
# TODO(mattjj): consider revising
def _device_get(x):
if isinstance(x, core.Tracer):
return x
try:
toarray = x.__array__
except AttributeError:
return x
else:
return toarray()
def device_get(x: Any):
  """Transfer ``x`` to host.

  If ``x`` is a pytree, then the individual buffers are copied in parallel.

  Args:
    x: An array, scalar, DeviceArray or (nested) standard Python container thereof
      representing the array to be transferred to host.

  Returns:
    An array or (nested) Python container thereof representing the
    value of ``x``.

  Examples:
    Passing a DeviceArray:

    >>> import jax
    >>> x = jax.numpy.array([1., 2., 3.])
    >>> jax.device_get(x)
    array([1., 2., 3.], dtype=float32)

    Passing a scalar (has no effect):

    >>> jax.device_get(1)
    1

  See Also:
    - device_put
    - device_put_sharded
    - device_put_replicated
  """
  with config_explicit_device_get_scope():
    # First start every leaf's async host copy, then convert each leaf.
    leaves = tree_leaves(x)
    for leaf in leaves:
      try:
        leaf.copy_to_host_async()
      except AttributeError:
        pass  # leaf has no copy_to_host_async; nothing to start
    return tree_map(_device_get, x)
def _check_arg(arg):
  """Raise ``TypeError`` unless ``arg`` is a tracer or a valid JAX type."""
  if isinstance(arg, core.Tracer):
    return
  if _valid_jaxtype(arg):
    return
  raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
# TODO(mattjj,necula): this duplicates code in core.valid_jaxtype, but one
# internal user relies on it for duck-typing. must fix downstream user!
def _valid_jaxtype(arg):
  """Return whether ``arg`` is a valid JAX type.

  Tries ``xla.abstractify`` first. If that raises ``TypeError``, defers to
  ``core.valid_jaxtype``.
  """
  try:
    xla.abstractify(arg)  # faster than core.get_aval
    return True
  except TypeError:
    return core.valid_jaxtype(arg)
class ShapeDtypeStruct:
  """A lightweight record of ``shape``, ``dtype`` and ``named_shape``.

  Instances are produced by :py:func:`eval_shape` and by
  ``make_jaxpr(..., return_shape=True)``.
  """
  __slots__ = ["shape", "dtype", "named_shape"]

  def __init__(self, shape, dtype, named_shape=None):
    # Normalize to a tuple. A list shape would otherwise make ``__hash__``
    # fail and make equal shapes given as list vs. tuple compare unequal.
    self.shape = tuple(shape)
    # Opaque dtypes are stored as-is; everything else is normalized by np.dtype.
    self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
    self.named_shape = {} if named_shape is None else dict(named_shape)

  size = property(lambda self: prod(self.shape))
  ndim = property(lambda self: len(self.shape))

  def __len__(self):
    try:
      return self.shape[0]
    except IndexError as e:
      raise TypeError("len() of unsized object") from e  # same as numpy error

  def __repr__(self):
    ns = f", named_shape={self.named_shape}" if self.named_shape else ""
    return f"{type(self).__name__}(shape={self.shape}, dtype={self.dtype.name}{ns})"

  __str__ = __repr__

  def __eq__(self, other):
    if not isinstance(other, ShapeDtypeStruct):
      return False
    else:
      return (other.shape, other.dtype, other.named_shape) == (
          self.shape, self.dtype, self.named_shape)

  def __hash__(self):
    # TODO(frostig): avoid the conversion from dict by addressing
    # https://github.com/google/jax/issues/8182
    named = frozenset(self.named_shape.items())
    return hash((self.shape, self.dtype, named))
def _shape_dtype_struct_aval(x):
  """Map a ``ShapeDtypeStruct`` to a non-weakly-typed ``ShapedArray``."""
  return ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype),
                     weak_type=False, named_shape=x.named_shape)

core.pytype_aval_mappings[ShapeDtypeStruct] = _shape_dtype_struct_aval
def eval_shape(fun: Callable, *args, **kwargs):
  """Compute the shape/dtype of ``fun`` without any FLOPs.

  This utility function is useful for performing shape inference. Its
  input/output behavior is defined by::

    def eval_shape(fun, *args, **kwargs):
      out = fun(*args, **kwargs)
      return jax.tree_util.tree_map(shape_dtype_struct, out)

    def shape_dtype_struct(x):
      return ShapeDtypeStruct(x.shape, x.dtype)

    class ShapeDtypeStruct:
      __slots__ = ["shape", "dtype"]
      def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

  In particular, the output is a pytree of objects that have ``shape`` and
  ``dtype`` attributes, but nothing else about them is guaranteed by the API.

  Instead of applying ``fun`` directly, which might be expensive, it uses JAX's
  abstract interpretation machinery to evaluate the shapes without doing any
  FLOPs.

  Using :py:func:`eval_shape` can also catch shape errors, and will raise the
  same shape errors as evaluating ``fun(*args, **kwargs)``.

  Args:
    fun: The function whose output shape should be evaluated.
    *args: a positional argument tuple of arrays, scalars, or (nested) standard
      Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
      those types. Since only the ``shape`` and ``dtype`` attributes are
      accessed, only values that duck-type arrays are required, rather than real
      ndarrays. The duck-typed objects cannot be namedtuples because those are
      treated as standard Python containers. See the example below.
    **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
      Python containers (pytrees) of those types. As in ``args``, array values
      need only be duck-typed to have ``shape`` and ``dtype`` attributes.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
  >>> class MyArgArray(object):
  ...   def __init__(self, shape, dtype):
  ...     self.shape = shape
  ...     self.dtype = jnp.dtype(dtype)
  ...
  >>> A = MyArgArray((2000, 3000), jnp.float32)
  >>> x = MyArgArray((3000, 1000), jnp.float32)
  >>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
  >>> print(out.shape)
  (2000, 1000)
  >>> print(out.dtype)
  float32
  """
  args_flat, in_tree = tree_flatten((args, kwargs))
  wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  debug_info = pe.debug_info(fun, in_tree, True, "eval_shape")
  in_avals = [shaped_abstractify(a) for a in args_flat]
  out_avals = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *in_avals,
                                   debug_info=debug_info)
  structs = [ShapeDtypeStruct(a.shape, a.dtype, a.named_shape)
             for a in out_avals]
  return tree_unflatten(out_tree(), structs)
@functools.wraps(new_checkpoint)  # config.jax_new_checkpoint is True by default
def checkpoint(fun: Callable, *,
               concrete: bool = False,
               prevent_cse: bool = True,
               static_argnums: Union[int, Tuple[int, ...]] = (),
               policy: Optional[Callable[..., bool]] = None,
               ) -> Callable:
  # The public docstring is copied from ``new_checkpoint`` by functools.wraps.
  # This wrapper only rejects the removed ``concrete`` option with a migration
  # message, then forwards the remaining options to ``new_checkpoint``.
  if concrete:
    # Fixes: the colon introducing the last "replace this use" example was
    # misplaced after its newline ("\n:"), and the embedded code examples now
    # keep their nested indentation so they read as valid Python.
    msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
           "in its place, you can use its `static_argnums` option, and if "
           "necessary the `jax.ensure_compile_time_eval()` context manager.\n"
           "\n"
           "For example, if using `concrete=True` for an `is_training` flag:\n"
           "\n"
           "  from functools import partial\n"
           "\n"
           "  @partial(jax.checkpoint, concrete=True)\n"
           "  def foo(x, is_training):\n"
           "    if is_training:\n"
           "      return f(x)\n"
           "    else:\n"
           "      return g(x)\n"
           "\n"
           "replace it with a use of `static_argnums`:\n"
           "\n"
           "  @partial(jax.checkpoint, static_argnums=(1,))\n"
           "  def foo(x, is_training):\n"
           "    ...\n"
           "\n"
           "If jax.numpy operations need to be performed on static arguments, "
           "we can use the `jax.ensure_compile_time_eval()` context manager. "
           "For example, we can replace this use of `concrete=True`:\n"
           "\n"
           "  @partial(jax.checkpoint, concrete=True)\n"
           "  def foo(x, y):\n"
           "    if y > 0:\n"
           "      return f(x)\n"
           "    else:\n"
           "      return g(x)\n"
           "\n"
           "with this combination of `static_argnums` and "
           "`jax.ensure_compile_time_eval()`:\n"
           "\n"
           "  @partial(jax.checkpoint, static_argnums=(1,))\n"
           "  def foo(x, y):\n"
           "    with jax.ensure_compile_time_eval():\n"
           "      y_pos = y > 0\n"
           "    if y_pos:\n"
           "      return f(x)\n"
           "    else:\n"
           "      return g(x)\n"
           "\n"
           "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
    raise NotImplementedError(msg)
  return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
                        static_argnums=static_argnums)

remat = checkpoint  # type: ignore
def named_call(
    fun: Callable[..., Any],
    *,
    name: Optional[str] = None,
) -> Callable[..., Any]:
  """Adds a user specified name to a function when staging out JAX computations.

  When staging out computations for just-in-time compilation to XLA (or other
  backends such as TensorFlow) JAX runs your Python program but by default does
  not preserve any of the function names or other metadata associated with it.
  This can make debugging the staged out (and/or compiled) representation of
  your program complicated because there is limited context information for each
  operation being executed.

  `named_call` tells JAX to stage the given function out as a subcomputation
  with a specific name. When the staged out program is compiled with XLA these
  named subcomputations are preserved and show up in debugging utilities like
  the TensorFlow Profiler in TensorBoard. Names are also preserved when staging
  out JAX programs to TensorFlow using :func:`experimental.jax2tf.convert`.

  Args:
    fun: Function to be wrapped. This can be any Callable.
    name: Optional. The prefix to use to name all sub computations created
      within the name scope. Defaults to ``fun.__name__``; for callables that
      have no ``__name__`` (e.g. ``functools.partial`` objects or instances
      defining ``__call__``), the name of ``type(fun)`` is used instead.

  Returns:
    A version of `fun` that is wrapped in a name_scope.
  """
  if name is None:
    # Not every callable has ``__name__``; fall back to its type's name rather
    # than raising AttributeError for an otherwise valid callable.
    name = getattr(fun, "__name__", type(fun).__name__)

  if config.jax_experimental_name_stack:
    # With the experimental name stack enabled, naming is done by wrapping
    # every call of ``fun`` in the ``extend_name_stack(name)`` context.
    return source_info_util.extend_name_stack(name)(fun)

  # ``fun`` and its arguments are captured by a nullary lambda below, so the
  # flattened input tree is always that of the empty tuple. Only needed on this
  # (non-name-stack) path, so it is computed after the early return.
  _, in_tree = tree_flatten(())

  @functools.wraps(fun)
  def named_call_f(*args, **kwargs):
    lu_f = lu.wrap_init(lambda: fun(*args, **kwargs))
    flat_f, out_tree = flatten_fun_nokwargs(lu_f, in_tree)
    out_flat = core.named_call_p.bind(flat_f, name=name)
    return tree_unflatten(out_tree(), out_flat)

  return named_call_f
@contextmanager
def named_scope(
    name: str,
) -> Generator[None, None, None]:
  """Context manager that pushes a user-chosen name onto JAX's name stack.

  By default, when JAX stages out a computation for just-in-time compilation
  to XLA (or to another backend such as TensorFlow), it does not keep the
  names or other source metadata of the Python functions it encounters. With
  so little context attached to each operation, the staged-out (and/or
  compiled) program can be hard to debug.

  ``named_scope`` attaches extra annotations to the operations staged out
  while it is active. JAX tracks these annotations on an internal name stack.
  They survive XLA compilation and show up in debugging tools such as the
  TensorFlow Profiler in TensorBoard. They are also kept when a JAX program is
  converted to TensorFlow with :func:`experimental.jax2tf.convert`.

  Args:
    name: The prefix to use to name all operations created within the name
      scope.

  Yields:
    Yields ``None``, but enters a context in which `name` will be appended to
    the active name stack.

  Raises:
    ValueError: if ``name`` is not a ``str``. Because this is a
      generator-based context manager, the check happens when the context is
      entered rather than when ``named_scope`` is called.

  Examples:
    ``named_scope`` can be used as a context manager inside compiled functions:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def layer(w, x):
    ...   with jax.named_scope("dot_product"):
    ...     logits = w.dot(x)
    ...   with jax.named_scope("activation"):
    ...     return jax.nn.relu(logits)

    It can also be used as a decorator:

    >>> @jax.jit
    ... @jax.named_scope("layer")
    ... def layer(w, x):
    ...   logits = w.dot(x)
    ...   return jax.nn.relu(logits)
  """
  if isinstance(name, str):
    name_stack_ctx = source_info_util.extend_name_stack(name)
  else:
    raise ValueError("named_scope name argument must be a string.")
  with name_stack_ctx:
    yield
def effects_barrier():
  """Blocks the caller until already-dispatched side-effects have completed.

  This waits on the runtime tokens held by ``dispatch`` and returns ``None``.
  """
  tokens = dispatch.runtime_tokens
  tokens.block_until_ready()
def block_until_ready(x):
  """
  Tries to call a ``block_until_ready`` method on pytree leaves.

  Leaves that have no ``block_until_ready`` attribute (e.g. Python scalars or
  NumPy arrays) are returned unchanged.

  Args:
    x: a pytree, usually with at least some JAX array instances at its leaves.

  Returns:
    A pytree with the same structure and values of the input, where the values
    of all JAX array leaves are ready.

  Raises:
    Any exception raised by a leaf's own ``block_until_ready`` method,
    including ``AttributeError``, propagates to the caller.
  """
  def try_to_block(leaf):
    # Look the method up *before* calling it. A ``try/except AttributeError``
    # around the call would also swallow an AttributeError raised while
    # blocking, silently returning a leaf that was never made ready.
    block = getattr(leaf, "block_until_ready", None)
    if block is None:
      return leaf
    return block()

  return tree_map(try_to_block, x)
def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
                  *args: Any, **kwargs: Any):
  """Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.

  ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
  The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
  should also return NumPy arrays. Execution takes place on CPU, like any
  Python+NumPy function.

  The callback is treated as functionally pure, meaning it has no side-effects
  and its output value depends only on its argument values. As a consequence, it
  is safe to be called multiple times (e.g. when transformed by ``vmap`` or
  ``pmap``), or not to be called at all when e.g. the output of a
  `jit`-decorated function has no data dependence on its value. Pure callbacks
  may also be reordered if data-dependence allows.

  When ``pmap``-ed, the pure callback will be called several times (once for
  each index along the mapped axis). When ``vmap``-ed, the behavior will depend
  on the value of the ``vectorized`` keyword argument. When ``vectorized`` is
  ``True``, the callback is assumed to obey
  ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
  Therefore, the callback will be called directly on batched inputs (where the
  batch axes are the leading dimensions). Additionally, the callbacks should
  return outputs that have corresponding leading batch axes. If not vectorized,
  ``callback`` will be mapped sequentially across the batched axis.
  For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
  to set ``vectorized=True`` because the ``np.matmul`` function handles
  arbitrary leading batch dimensions.

  Args:
    callback: A Python callable. The callable will be passed PyTrees of NumPy
      arrays as arguments, and should return a PyTree of NumPy arrays that
      matches ``result_shape_dtypes``.
    result_shape_dtypes: A PyTree with leaves that are objects with ``shape``
      and ``dtype`` attributes which represent the shapes and dtypes of the
      value of ``callback`` applied to ``args`` and ``kwargs``.
    *args: The positional arguments to the callback. Must be PyTrees of JAX
      types.
    vectorized: A boolean that indicates whether or not ``callback`` is
      vectorized, meaning it can handle arrays with additional leading
      dimensions. If ``vectorized`` is `True`, when the callback is mapped
      via `jax.vmap`, it will be called directly on inputs with leading batch
      dimensions instead of executing ``callback`` on each mapped input
      individually. The callback should also return outputs batched across the
      leading axis.
    **kwargs: The keyword arguments to the callback. Must be PyTrees of JAX
      types.

  Returns:
    The value of ``callback(*args, **kwargs)``.
  """
  # Thin public wrapper: every argument, including everything in ``kwargs``,
  # is forwarded unchanged to the internal implementation in
  # ``jax._src.callback``.
  return jcb.pure_callback(callback, result_shape_dtypes, *args, **kwargs)
def clear_backends():
"""
Clear all backend clients so that new backend clients can be created later.
"""
xb._clear_backends()
jax.lib.xla_bridge._backends = {}
dispatch.xla_callable.cache_clear() # type: ignore
dispatch.xla_primitive_callable.cache_clear()
_cpp_jit_cache.clear()
jax_jit.CompiledFunctionCache.clear_all()