Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
prepare to switch to new remat
This commit involves a few things, which are all united in being about landing the new remat (aka new checkpoint) implementation:

* add benchmarks for new remat eager performance, and some caching to make those benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an actionable message pointing to the new recommended approach involving static_argnums
* add the static_argnums parameter to both new and old remat
* update docstrings (and de-duplicate them too)
* add new tests, especially around caching and errors/warnings
parent 40c12e376e
commit e3a92d52ba
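Before the diff itself, here is a minimal sketch (an editor's illustration, not part of the commit) of the migration this commit enables: the deprecated `concrete=True` option is replaced by `static_argnums`, mirroring the docstring example added below. The function `foo` and its branch bodies are hypothetical placeholders.

    from functools import partial
    import jax
    import jax.numpy as jnp

    # Mark the flag argument static so Python control flow on it runs at
    # trace time; `foo` is retraced for each new value of `is_training`.
    @partial(jax.checkpoint, static_argnums=(1,))
    def foo(x, is_training):
      if is_training:
        return jnp.sin(x)  # hypothetical branch bodies
      else:
        return jnp.cos(x)

    y, vjp_fun = jax.vjp(lambda x: foo(x, True), 3.0)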
@@ -22,6 +22,7 @@ from jax import lax
from jax._src import test_util as jtu
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax.interpreters import pxla
from jax.experimental import sharding
@@ -44,6 +45,9 @@ def required_devices(num_devices_required):
    return helper2
  return helper1

def swap(a, b):
  return b, a


@google_benchmark.register
def eager_unary_dispatch(state):
@@ -504,8 +508,35 @@ def bench_pjit_check_aval_sharding(state):
    pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)


def swap(a, b):
  return b, a

@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads(state):
  def double_compose(f):
    return lambda x: f(f(x))

  f = jnp.sin
  for _ in range(6):
    f = double_compose(f)
  f = double_compose(checkpoint(f))

  while state:
    y, _ = jax.vjp(f, 3.)
    y.block_until_ready()

@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads_static_argnums(state):
  def double_compose(f):
    return lambda x, y: f(f(x, y), y)

  f = lambda x, _: jnp.sin(x)
  for _ in range(6):
    f = double_compose(f)
  f = double_compose(checkpoint(f, static_argnums=(1,)))

  while state:
    y, _ = jax.vjp(f, 3., True)
    y.block_until_ready()


if __name__ == "__main__":
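A standalone version of the first benchmark body above, using the public `jax.checkpoint` entry point rather than the internal `jax._src.ad_checkpoint` import, in case you want to reproduce the eager retracing behavior outside the google_benchmark harness:

    import jax
    import jax.numpy as jnp

    def double_compose(f):
      return lambda x: f(f(x))

    f = jnp.sin
    for _ in range(6):
      f = double_compose(f)          # 2**6 = 64 applications of sin
    f = double_compose(jax.checkpoint(f))

    y, vjp_fun = jax.vjp(f, 3.)      # eager vjp: tracing/caching dominates
    y.block_until_ready()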
@@ -14,7 +14,7 @@

from functools import partial
import operator as op
from typing import Callable, Optional, List, Tuple, Sequence, Union
from typing import Callable, Optional, List, Tuple, Sequence, Union, Any
import types

import jax
@@ -96,8 +96,9 @@ checkpoint_policies = types.SimpleNamespace(

### Main API

def checkpoint(fun: Callable, prevent_cse: bool = True,
               policy: Optional[Callable[..., bool]] = None
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
               policy: Optional[Callable[..., bool]] = None,
               static_argnums: Union[int, Tuple[int, ...]] = (),
               ) -> Callable:
  """Make ``fun`` recompute internal linearization points when differentiated.

@@ -131,25 +132,26 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
      from the default of storing all intermediate linearization points to
      recomputing them. Its arguments and return value should be arrays,
      scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
    concrete: Optional, boolean indicating whether ``fun`` may involve
      value-dependent Python control flow (default False). Support for such
      control flow is optional, and disabled by default, because in some
      edge-case compositions with :func:`jax.jit` it can lead to some extra
      computation.
    prevent_cse: Optional, boolean indicating whether to prevent common
      subexpression elimination (CSE) optimizations in the HLO generated from
      differentiation. This CSE prevention has costs because it can foil other
      optimizations, and because it can incur high overheads on some backends,
      especially GPU. The default is True because otherwise, under a ``jit`` or
      ``pmap``, CSE can defeat the purpose of this decorator. But in some
      settings, like when used inside a ``scan``, this CSE prevention mechanism
      is unnecessary, in which case ``prevent_cse`` can be set to False.
    policy: This is an experimental feature and the API is likely to change.
      Optional callable, one of the attributes of ``jax.checkpoint_policies``,
      which takes as input a type-level specification of a first-order primitive
      application and returns a boolean indicating whether the corresponding
      output value(s) can be saved as a residual (or, if not, instead must be
      recomputed in the (co)tangent computation).
    prevent_cse: Optional, boolean keyword-only argument indicating whether to
      prevent common subexpression elimination (CSE) optimizations in the HLO
      generated from differentiation. This CSE prevention has costs because it
      can foil other optimizations, and because it can incur high overheads on
      some backends, especially GPU. The default is True because otherwise,
      under a ``jit`` or ``pmap``, CSE can defeat the purpose of this decorator.
      But in some settings, like when used inside a ``scan``, this CSE
      prevention mechanism is unnecessary, in which case ``prevent_cse`` can be
      set to False.
    static_argnums: Optional, int or sequence of ints, a keyword-only argument
      indicating which argument values to specialize on for tracing and
      caching purposes. Specifying arguments as static can avoid
      ConcretizationTypeErrors when tracing, but at the cost of more retracing
      overheads. See the example below.
    policy: Optional, callable keyword-only argument. It should be one of the
      attributes of ``jax.checkpoint_policies``. The callable takes as input a
      type-level specification of a first-order primitive application and
      returns a boolean indicating whether the corresponding output value(s) can
      be saved as residuals (or instead must be recomputed in the (co)tangent
      computation if needed).

  Returns:
    A function (callable) with the same input/output behavior as ``fun`` but
@@ -203,13 +205,46 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
  ...   f2 = recursive_checkpoint(funs[len(funs)//2:])
  ...   return lambda x: f1(jax.checkpoint(f2)(x))
  ...

  If ``fun`` involves Python control flow which depends on argument values,
  it may be necessary to use the ``static_argnums`` parameter. For example,
  consider a boolean flag argument:

    from functools import partial

    @partial(jax.checkpoint, static_argnums=(1,))
    def foo(x, is_training):
      if is_training:
        ...
      else:
        ...

  Here, the use of ``static_argnums`` is necessary because the ``if`` statement
  depends on the value of ``is_training``. The only cost to using
  ``static_argnums`` is more retracing overheads: in the example, ``foo`` must
  be retraced for every new value of ``is_training``. In some cases it may
  additionally be necessary to use ``jax.ensure_compile_time_eval``:

    @partial(jax.checkpoint, static_argnums=(1,))
    def foo(x, y):
      with jax.ensure_compile_time_eval():
        y_pos = y > 0
      if y_pos:
        ...
      else:
        ...

  As an alternative to using ``static_argnums`` (and
  ``jax.ensure_compile_time_eval``), it may be easier to compute some values
  outside the ``jax.checkpoint``-decorated function and then close over them.
  """
  @wraps(fun)
  @api_boundary
  def fun_remat(*args, **kwargs):
    fun_, args = _remat_static_argnums(fun, static_argnums, args)
    args_flat, in_tree = tree_flatten((args, kwargs))
    in_avals = [shaped_abstractify(x) for x in args_flat]
    jaxpr, consts, out_tree = _trace_to_jaxpr(fun, in_tree, tuple(in_avals))
    jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals))
    out_flat = remat_p.bind(
        *consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse,
        differentiated=False, policy=policy)
@@ -218,11 +253,90 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,

remat = checkpoint # alias

# This function is similar to api_util.argnums_partial, except the error
# messages are specific to jax.remat (and thus more actionable), the
# hashing/caching behavior is slightly different, and this function accepts a
# boolean for static_argnums. Perhaps the two could be de-duplicated.
def _remat_static_argnums(fun, static_argnums, args):
  if type(static_argnums) is int:
    static_argnums = (static_argnums,)
  elif not (type(static_argnums) is tuple and
            all(type(d) is int for d in static_argnums)):
    raise TypeError("the `static_argnums` argument to `jax.checkpoint` / "
                    "`jax.remat` must be an int, tuple of ints, or bool, but "
                    f"got value {static_argnums}")

  if not all(-len(args) <= d < len(args) for d in static_argnums):
    raise ValueError("the `static_argnums` argument to `jax.checkpoint` / "
                     "`jax.remat` can only take integer values greater than or "
                     "equal to `-len(args)` and less than `len(args)`, but got "
                     f"{static_argnums}")

  if not static_argnums:
    return fun, args
  nargs = len(args)
  static_argnums_ = frozenset(d % len(args) for d in static_argnums)
  dyn_args, static_args = [], []
  for i, x in enumerate(args):
    if i in static_argnums_: static_args.append(WrapHashably(x))
    else: dyn_args.append(x)
  new_fun = _dyn_args_fun(fun, static_argnums_, tuple(static_args), nargs)
  return new_fun, dyn_args

class WrapHashably:
  val: Any
  hash: Optional[int] = None
  hashable: bool

  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash(val)
      self.hashable = True
    except:
      self.hash = id(val)
      self.hashable = False
  def __hash__(self):
    return self.hash
  def __eq__(self, other):
    if isinstance(other, WrapHashably):
      try: return self.val == other.val
      except: return self.val is other.val
    return False

# This caching is useful to avoid retracing even when static_argnums is used.
# See api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums.
# On that benchmark, including this caching makes a ~10x difference (which can
# be made arbitrarily large by involving larger functions to be traced).
@weakref_lru_cache
def _dyn_args_fun(fun: Callable, static_argnums: Tuple[int, ...],
                  static_args: Tuple[WrapHashably, ...], nargs: int):
  def new_fun(*dyn_args, **kwargs):
    static_args_, dyn_args_ = iter(static_args), iter(dyn_args)
    full_args = [next(static_args_).val if i in static_argnums
                 else next(dyn_args_) for i in range(nargs)]
    return fun(*full_args, **kwargs)
  return new_fun

# This helper is similar to those in control_flow/common.py, but with
# remat-specific errors.
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
  debug = pe.debug_info(fun, in_tree, False, "checkpoint")
  debug = pe.debug_info(fun, in_tree, True, "checkpoint")
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  try:
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  except core.ConcretizationTypeError as e:
    msg, = e.args
    new_msg = msg + "\n\n" + (
        "Consider using the `static_argnums` parameter for `jax.remat` or "
        "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example "
        "involving `static_argnums`:\n"
        "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html"
        "\n")
    new_e = core.ConcretizationTypeError.__new__(core.ConcretizationTypeError)
    new_e.args = (new_msg,)
    raise new_e from None
  return pe.convert_constvars_jaxpr(jaxpr), consts, out_tree()
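To see both code paths of `_trace_to_jaxpr` above from user code, here is a rough sketch (an editor's illustration, assuming the new-checkpoint implementation is enabled via `config.jax_new_checkpoint`; `f`, `g`, and their bodies are placeholders):

    import jax
    from functools import partial

    @jax.checkpoint
    def f(x, flag):
      return x if flag else -x  # Python branching on a traced value

    @partial(jax.checkpoint, static_argnums=(1,))
    def g(x, flag):
      return x if flag else -x  # `flag` stays a Python bool at trace time

    try:
      f(1., True)  # `flag` is abstracted, so `bool(flag)` fails while tracing
    except jax.errors.ConcretizationTypeError as e:
      assert "static_argnums" in str(e)  # the hint appended above
    print(g(1., True))  # works: the branch is resolved during tracing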
jax/_src/api.py
@@ -53,7 +53,8 @@ from jax._src.api_util import (
    flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
    argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
    rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
    shaped_abstractify, _ensure_str_tuple, argnames_partial_except, validate_argnames, validate_argnums)
    shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
    validate_argnames, validate_argnums)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
@@ -63,7 +64,7 @@ from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
                           extend_name_stack, new_name_stack, wrap_name, cache,
                           wraps, HashableFunction)
                           wraps, HashableFunction, weakref_lru_cache)

# Unused imports to be exported
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
@@ -71,6 +72,7 @@ from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
                                     process_count, host_id, host_ids,
                                     host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
from jax._src.ad_checkpoint import _remat_static_argnums
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
@@ -660,12 +662,12 @@ def disable_jit():
  """Context manager that disables :py:func:`jit` behavior under its dynamic context.

  For debugging it is useful to have a mechanism that disables :py:func:`jit`
  everywhere in a dynamic context. Note that this not only disables explicit uses
  of `jit` by the user, but will also remove any implicit JIT compilation used by the
  JAX library: this includes implicit JIT computation of `body` and `cond`
  functions passed to higher-level primitives like :func:`scan` and :func:`while_loop`,
  JIT used in implementations of :mod:`jax.numpy` functions, and any other case where
  `jit` is used within an API's implementation.
  everywhere in a dynamic context. Note that this not only disables explicit
  uses of `jit` by the user, but will also remove any implicit JIT compilation
  used by the JAX library: this includes implicit JIT computation of `body` and
  `cond` functions passed to higher-level primitives like :func:`scan` and
  :func:`while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
  and any other case where `jit` is used within an API's implementation.

  Values that have a data dependence on the arguments to a jitted function are
  traced and abstracted. For example, an abstract value may be a
@@ -3046,130 +3048,81 @@ def eval_shape(fun: Callable, *args, **kwargs):
  return tree_unflatten(out_tree(), out)


def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
def checkpoint(fun: Callable, *,
               concrete: bool = False,
               prevent_cse: bool = True,
               static_argnums: Union[int, Tuple[int, ...], bool] = (),
               policy: Optional[Callable[..., bool]] = None,
               ) -> Callable:
  """Make ``fun`` recompute internal linearization points when differentiated.
  if concrete:
    msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
           "in its place, you can use its `static_argnums` option, and if "
           "necessary the `jax.ensure_compile_time_eval()` context manager.\n"
           "\n"
           "For example, if using `concrete=True` for an `is_training` flag:\n"
           "\n"
           "  from functools import partial\n"
           "\n"
           "  @partial(jax.checkpoint, concrete=True)\n"
           "  def foo(x, is_training):\n"
           "    if is_training:\n"
           "      return f(x)\n"
           "    else:\n"
           "      return g(x)\n"
           "\n"
           "replace it with a use of `static_argnums`:\n"
           "\n"
           "  @partial(jax.checkpoint, static_argnums=(1,))\n"
           "  def foo(x, is_training):\n"
           "    ...\n"
           "\n"
           "If jax.numpy operations need to be performed on static arguments, "
           "we can use the `jax.ensure_compile_time_eval()` context manager. "
           "For example, we can replace this use of `concrete=True`:\n"
           "\n"
           "  @partial(jax.checkpoint, concrete=True)\n"
           "  def foo(x, y):\n"
           "    if y > 0:\n"
           "      return f(x)\n"
           "    else:\n"
           "      return g(x)\n"
           "\n"
           "with this combination of `static_argnums` and "
           "`jax.ensure_compile_time_eval()`:\n"
           "\n"
           "  @partial(jax.checkpoint, static_argnums=(1,))\n"
           "  def foo(x, y):\n"
           "    with jax.ensure_compile_time_eval():\n"
           "      y_pos = y > 0\n"
           "    if y_pos:\n"
           "      return f(x)\n"
           "    else:\n"
           "      return g(x)\n"
           "\n")
    if config.jax_new_checkpoint:
      raise NotImplementedError(msg)
    else:
      warn(msg, DeprecationWarning)

  The :func:`jax.checkpoint` decorator, aliased to ``jax.remat``, provides a
  way to trade off computation time and memory cost in the context of automatic
  differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
  and :func:`jax.vjp` but also with :func:`jax.linearize`.

  When differentiating a function in reverse-mode, by default all the
  linearization points (e.g. inputs to elementwise nonlinear primitive
  operations) are stored when evaluating the forward pass so that they can be
  reused on the backward pass. This evaluation strategy can lead to a high
  memory cost, or even to poor performance on hardware accelerators where memory
  access is much more expensive than FLOPs.

  An alternative evaluation strategy is for some of the linearization points to
  be recomputed (i.e. rematerialized) rather than stored. This approach can
  reduce memory usage at the cost of increased computation.

  This function decorator produces a new version of ``fun`` which follows
  the rematerialization strategy rather than the default store-everything
  strategy. That is, it returns a new version of ``fun`` which, when
  differentiated, doesn't store any of its intermediate linearization points.
  Instead, these linearization points are recomputed from the function's saved
  inputs.

  See the examples below.

  Args:
    fun: Function for which the autodiff evaluation strategy is to be changed
      from the default of storing all intermediate linearization points to
      recomputing them. Its arguments and return value should be arrays,
      scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
    concrete: Optional, boolean indicating whether ``fun`` may involve
      value-dependent Python control flow (default False). Support for such
      control flow is optional, and disabled by default, because in some
      edge-case compositions with :func:`jax.jit` it can lead to some extra
      computation.
    prevent_cse: Optional, boolean indicating whether to prevent common
      subexpression elimination (CSE) optimizations in the HLO generated from
      differentiation. This CSE prevention has costs because it can foil other
      optimizations, and because it can incur high overheads on some backends,
      especially GPU. The default is True because otherwise, under a ``jit`` or
      ``pmap``, CSE can defeat the purpose of this decorator. But in some
      settings, like when used inside a ``scan``, this CSE prevention mechanism
      is unnecessary, in which case ``prevent_cse`` can be set to False.
    policy: Optional callable, one of the attributes of
      ``jax.checkpoint_policies``, which takes as input a type-level
      specification of a first-order primitive application and returns a boolean
      indicating whether the corresponding output value(s) can be saved as a
      residual (or instead must be recomputed in the (co)tangent computation if
      needed).

  Returns:
    A function (callable) with the same input/output behavior as ``fun`` but
    which, when differentiated using e.g. :func:`jax.grad`, :func:`jax.vjp`, or
    :func:`jax.linearize`, recomputes rather than stores intermediate
    linearization points, thus potentially saving memory at the cost of extra
    computation.

  Here is a simple example:

  >>> import jax
  >>> import jax.numpy as jnp

  >>> @jax.checkpoint
  ... def g(x):
  ...   y = jnp.sin(x)
  ...   z = jnp.sin(y)
  ...   return z
  ...
  >>> jax.value_and_grad(g)(2.0)
  (DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))

  Here, the same value is produced whether or not the :func:`jax.checkpoint`
  decorator is present. When the decorator is not present, the values
  ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
  pass and are stored for use in the backward pass, because they are needed
  on the backward pass and depend only on the primal inputs. When using
  :func:`jax.checkpoint`, the forward pass will compute only the primal outputs
  and only the primal inputs (``2.0``) will be stored for the backward pass.
  At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
  ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.

  While ``jax.checkpoint`` controls what values are stored from the forward-pass
  to be used on the backward pass, the total amount of memory required to
  evaluate a function or its VJP depends on many additional internal details of
  that function. Those details include which numerical primitives are used,
  how they're composed, where jit and control flow primitives like scan
  are used, and other factors.

  The :func:`jax.checkpoint` decorator can be applied recursively to express
  sophisticated autodiff rematerialization strategies. For example:

  >>> def recursive_checkpoint(funs):
  ...   if len(funs) == 1:
  ...     return funs[0]
  ...   elif len(funs) == 2:
  ...     f1, f2 = funs
  ...     return lambda x: f1(f2(x))
  ...   else:
  ...     f1 = recursive_checkpoint(funs[:len(funs)//2])
  ...     f2 = recursive_checkpoint(funs[len(funs)//2:])
  ...     return lambda x: f1(jax.checkpoint(f2)(x))
  ...
  """
  if config.jax_new_checkpoint and not concrete:
    return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy)
  if config.jax_new_checkpoint:
    return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
                          static_argnums=static_argnums)

  @wraps(fun)
  @api_boundary
  def remat_f(*args, **kwargs):
    f, args = _remat_static_argnums(fun, static_argnums, args)
    args_flat, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    flat_fun, out_tree = flatten_fun(lu.wrap_init(f), in_tree)
    out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
                             concrete=concrete, prevent_cse=prevent_cse,
                             differentiated=False,
                             policy=policy)
                             differentiated=False, policy=policy)
    return tree_unflatten(out_tree(), out_flat)
  return remat_f
checkpoint.__doc__ = new_checkpoint.__doc__
remat = checkpoint # type: ignore


def named_call(
    fun: Callable[..., Any],
    *,
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from jax import core

@@ -145,7 +146,7 @@ class ConcretizationTypeError(JAXTypeError):
  and concrete vs. abstract values, you may want to read
  :ref:`faq-different-kinds-of-jax-values`.
  """
  def __init__(self, tracer: "core.Tracer", context: str = ""):
  def __init__(self, tracer: core.Tracer, context: str = ""):
    super().__init__(
        "Abstract tracer value encountered where concrete value is expected: "
        f"{tracer}\n{context}{tracer._origin_msg()}\n")
@@ -237,7 +238,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
  >>> manual_clip(jnp.arange(-2, 2))
  DeviceArray([0, 0, 0, 1], dtype=int32)
  """
  def __init__(self, tracer: "core.Tracer"):
  def __init__(self, tracer: core.Tracer):
    super().__init__(
        f"Array boolean indices must be concrete; got {tracer}\n")

@@ -316,7 +317,7 @@ class TracerArrayConversionError(JAXTypeError):
  and concrete vs. abstract values, you may want to read
  :ref:`faq-different-kinds-of-jax-values`.
  """
  def __init__(self, tracer: "core.Tracer"):
  def __init__(self, tracer: core.Tracer):
    super().__init__(
        "The numpy.ndarray conversion method __array__() was called on "
        f"the JAX Tracer object {tracer}{tracer._origin_msg()}")
@@ -409,7 +410,7 @@ class TracerIntegerConversionError(JAXTypeError):
  and concrete vs. abstract values, you may want to read
  :ref:`faq-different-kinds-of-jax-values`.
  """
  def __init__(self, tracer: "core.Tracer"):
  def __init__(self, tracer: core.Tracer):
    super().__init__(
        f"The __index__() method was called on the JAX Tracer object {tracer}")
@@ -908,7 +908,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
  local_devices = list(api.local_devices())
  if len(local_devices) < size:
    raise unittest.SkipTest(f"Test requires {size} local devices")
  mesh_devices = np.array(local_devices[:size]).reshape(shape)
  mesh_devices = np.array(local_devices[:size]).reshape(shape) # type: ignore
  with Mesh(mesh_devices, axis_names):
    yield
@@ -3911,8 +3911,11 @@ class RematTest(jtu.JaxTestCase):
    expected = f_lin_expected(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @unittest.skipIf(config.jax_new_checkpoint, "test is for old remat only")
  def test_remat_grad_python_control_flow(self):
    @partial(api.remat, concrete=True)
    # See test `test_remat_grad_python_control_flow_static_argnums` for the
    # new recommended way to express this computation.

    def g(x):
      if x > 0:
        return lax.sin(x), 3.
@@ -3923,6 +3926,75 @@ class RematTest(jtu.JaxTestCase):
      x, _ = g(x)
      return x

    with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
      g = api.remat(g, concrete=True)

    ans = f(2.)
    expected = np.sin(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(f)(2.)
    expected = np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @unittest.skipIf(config.jax_new_checkpoint, "new remat raises error here")
  def test_remat_concrete_deprecation_warning(self):
    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
      _ = api.remat(g, concrete=True)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat warns here")
  def test_remat_concrete_deprecation_error(self):
    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    with self.assertRaisesRegex(NotImplementedError, "static_argnums"):
      _ = api.remat(g, concrete=True)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat different error")
  def test_remat_concrete_error(self):
    @api.remat # no static_argnums or concrete
    def g(x):
      if x > 0:
        return lax.sin(x)
      else:
        return lax.cos(x)

    with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
      g(3.)

    @partial(api.remat, static_argnums=(0,)) # using static_argnums but...
    def g(x):
      if x > 0: # jnp operations still get staged!
        return lax.sin(x)
      else:
        return lax.cos(x)

    with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
      g(jnp.array(3.))

  def test_remat_grad_python_control_flow_static_argnums(self):
    @partial(api.remat, static_argnums=(0,))
    def g(x):
      with jax.ensure_compile_time_eval():
        x_pos = x > 0
      if x_pos:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    def f(x):
      x, _ = g(x)
      return x

    ans = f(2.)
    expected = np.sin(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)
@@ -3931,6 +4003,89 @@ class RematTest(jtu.JaxTestCase):
    expected = np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_grad_python_control_flow_unhashable_static_argnums(self):
    @partial(api.remat, static_argnums=(0,))
    def g(x):
      x = x.val
      with jax.ensure_compile_time_eval():
        x_pos = x > 0
      if x_pos:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    def f(x):
      x, _ = g(x)
      return x

    class A:
      def __init__(self, val):
        self.val = val
      def __hash__(self):
        raise TypeError

    ans = f(A(2.))
    expected = np.sin(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: f(A(x)))(2.)
    expected = np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
  def test_remat_retracing(self):
    # This is *not* a very important behavior; remat doesn't need to provide
    # caching guarantees with the same importance as jit. But even so, in the
    # interest of not redoing tracing work (and thus making jax.remat more
    # feasible to use in eager mode), this test checks that we don't re-trace
    # the remat-decorated function.
    count = 0

    @api.remat
    def g(x):
      nonlocal count
      count += 1
      return lax.sin(x), 3.

    def f(x):
      x, _ = g(x)
      return x

    for _ in range(10):
      y = f(2.)
      y.block_until_ready()
    self.assertEqual(count, 1)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
  def test_remat_static_argnums_retracing(self):
    # This is *not* a super important behavior; remat doesn't need to provide
    # caching guarantees with the same importance as jit. But even so, in the
    # interest of not redoing tracing work (and thus making jax.remat more
    # feasible to use in eager mode), this test checks that we don't re-trace
    # the remat-decorated function *even with static_argnums*. See also the
    # above test, which doesn't check for static_argnums.
    count = 0

    @partial(api.remat, static_argnums=(0,))
    def g(x):
      nonlocal count
      count += 1
      with jax.ensure_compile_time_eval():
        x_pos = x > 0
      if x_pos:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    def f(x):
      x, _ = g(x)
      return x

    for _ in range(10):
      y = f(2.)
      y.block_until_ready()
    self.assertEqual(count, 1)

  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [
@@ -4784,6 +4939,15 @@ class RematTest(jtu.JaxTestCase):
      f_vjp(1.)[0].block_until_ready()
    self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd

  def test_vjp_caching_static_argnums(self):
    identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
                         static_argnums=(1,))
    _, f_vjp = jax.vjp(identity, 1., True)
    with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
      for _ in range(20):
        f_vjp(1.)[0].block_until_ready()
    self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
  def test_fwd_caching(self):
    # see above test also
@@ -4794,6 +4958,16 @@ class RematTest(jtu.JaxTestCase):
      y.block_until_ready()
    self.assertEqual(count[0], 1)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
  def test_fwd_caching_static_argnums(self):
    # see above test also
    identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,))
    with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
      for _ in range(20):
        y = identity(1.)
        y.block_until_ready()
    self.assertEqual(count[0], 1)

  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [