prepare to switch to new remat

This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
  * add benchmarks for new remat eager performance, and some caching to make those
    benchmarks fast
  * warn when the old-remat-exclusive `concrete` feature is used, with an
    actionable message pointing to the new recommended approach involving static_argnums
  * add the static_argnums parameter to both new and old remat
  * update docstrings (and de-duplicate them too)
  * add new tests, especially around caching and errors/warnings
This commit is contained in:
Matthew Johnson 2022-08-02 14:49:16 -07:00
parent 40c12e376e
commit e3a92d52ba
6 changed files with 424 additions and 151 deletions

View File

@ -22,6 +22,7 @@ from jax import lax
from jax._src import test_util as jtu
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax.interpreters import pxla
from jax.experimental import sharding
@ -44,6 +45,9 @@ def required_devices(num_devices_required):
return helper2
return helper1
def swap(a, b):
  """Return the two arguments in reversed order, as the tuple ``(b, a)``."""
  swapped = (b, a)
  return swapped
@google_benchmark.register
def eager_unary_dispatch(state):
@ -504,8 +508,35 @@ def bench_pjit_check_aval_sharding(state):
pjit_lib.pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
def swap(a, b):
  """Return the two arguments in reversed order, as the tuple ``(b, a)``."""
  swapped = (b, a)
  return swapped
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads(state):
  """Benchmark eager ``jax.vjp`` through a checkpointed, deeply nested sin.

  ``jnp.sin`` is composed with itself by six rounds of doubling, the result is
  wrapped in ``checkpoint``, and the checkpointed function is then applied
  twice per evaluation. Each benchmark iteration runs ``jax.vjp`` on that
  function and blocks on the primal output.
  """
  def double_compose(f):
    def composed(x):
      return f(f(x))
    return composed

  nested = jnp.sin
  for _ in range(6):
    nested = double_compose(nested)
  nested = double_compose(checkpoint(nested))

  while state:
    primal_out, _ = jax.vjp(nested, 3.)
    primal_out.block_until_ready()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_remat_eager_retracing_overheads_static_argnums(state):
  """Same benchmark as above, but with a static second argument.

  The nested function takes an extra argument that is ignored by the innermost
  ``jnp.sin`` call and is marked static via ``static_argnums=(1,)`` on
  ``checkpoint``. Each benchmark iteration runs ``jax.vjp`` with ``True`` for
  that argument and blocks on the primal output.
  """
  def double_compose(f):
    def composed(x, y):
      return f(f(x, y), y)
    return composed

  def sin_ignoring_flag(x, _):
    return jnp.sin(x)

  nested = sin_ignoring_flag
  for _ in range(6):
    nested = double_compose(nested)
  nested = double_compose(checkpoint(nested, static_argnums=(1,)))

  while state:
    primal_out, _ = jax.vjp(nested, 3., True)
    primal_out.block_until_ready()
if __name__ == "__main__":

View File

@ -14,7 +14,7 @@
from functools import partial
import operator as op
from typing import Callable, Optional, List, Tuple, Sequence, Union
from typing import Callable, Optional, List, Tuple, Sequence, Union, Any
import types
import jax
@ -96,8 +96,9 @@ checkpoint_policies = types.SimpleNamespace(
### Main API
def checkpoint(fun: Callable, prevent_cse: bool = True,
policy: Optional[Callable[..., bool]] = None
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
policy: Optional[Callable[..., bool]] = None,
static_argnums: Union[int, Tuple[int, ...]] = (),
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
@ -131,25 +132,26 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
from the default of storing all intermediate linearization points to
recomputing them. Its arguments and return value should be arrays,
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
concrete: Optional, boolean indicating whether ``fun`` may involve
value-dependent Python control flow (default False). Support for such
control flow is optional, and disabled by default, because in some
edge-case compositions with :func:`jax.jit` it can lead to some extra
computation.
prevent_cse: Optional, boolean indicating whether to prevent common
subexpression elimination (CSE) optimizations in the HLO generated from
differentiation. This CSE prevention has costs because it can foil other
optimizations, and because it can incur high overheads on some backends,
especially GPU. The default is True because otherwise, under a ``jit`` or
``pmap``, CSE can defeat the purpose of this decorator. But in some
settings, like when used inside a ``scan``, this CSE prevention mechanism
is unnecessary, in which case ``prevent_cse`` can be set to False.
policy: This is an experimental feature and the API is likely to change.
Optional callable, one of the attributes of ``jax.checkpoint_policies``,
which takes as input a type-level specification of a first-order primitive
application and returns a boolean indicating whether the corresponding
output value(s) can be saved as a residual (or, if not, instead must be
recomputed in the (co)tangent computation).
prevent_cse: Optional, boolean keyword-only argument indicating whether to
prevent common subexpression elimination (CSE) optimizations in the HLO
generated from differentiation. This CSE prevention has costs because it
can foil other optimizations, and because it can incur high overheads on
some backends, especially GPU. The default is True because otherwise,
under a ``jit`` or ``pmap``, CSE can defeat the purpose of this decorator.
But in some settings, like when used inside a ``scan``, this CSE
prevention mechanism is unnecessary, in which case ``prevent_cse`` can be
set to False.
static_argnums: Optional, int or tuple of ints, a keyword-only argument
indicating which arguments' values to specialize on for tracing and
caching purposes. Specifying arguments as static can avoid
ConcretizationTypeErrors when tracing, but at the cost of more retracing
overheads. See the example below.
policy: Optional, callable keyword-only argument. It should be one of the
attributes of ``jax.checkpoint_policies``. The callable takes as input a
type-level specification of a first-order primitive application and
returns a boolean indicating whether the corresponding output value(s) can
be saved as residuals (or instead must be recomputed in the (co)tangent
computation if needed).
Returns:
A function (callable) with the same input/output behavior as ``fun`` but
@ -203,13 +205,46 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
... f2 = recursive_checkpoint(funs[len(funs)//2:])
... return lambda x: f1(jax.checkpoint(f2)(x))
...
If ``fun`` involves Python control flow which depends on argument values,
it may be necessary to use the ``static_argnums`` parameter. For example,
consider a boolean flag argument:
from functools import partial
@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
if is_training:
...
else:
...
Here, the use of ``static_argnums`` is necessary because the ``if`` statement
depends on the value of ``is_training``. The only cost to using
``static_argnums`` is more retracing overheads: in the example, ``foo`` must
be retraced for every new value of ``is_training``. In some cases it may
additionally be necessary to use ``jax.ensure_compile_time_eval``:
@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, y):
with jax.ensure_compile_time_eval():
y_pos = y > 0
if y_pos:
...
else:
...
As an alternative to using ``static_argnums`` (and
``jax.ensure_compile_time_eval``), it may be easier to compute some values
outside the ``jax.checkpoint``-decorated function and then close over them.
"""
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
fun_, args = _remat_static_argnums(fun, static_argnums, args)
args_flat, in_tree = tree_flatten((args, kwargs))
in_avals = [shaped_abstractify(x) for x in args_flat]
jaxpr, consts, out_tree = _trace_to_jaxpr(fun, in_tree, tuple(in_avals))
jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals))
out_flat = remat_p.bind(
*consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse,
differentiated=False, policy=policy)
@ -218,11 +253,90 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
remat = checkpoint # alias
# This function is similar to api_util.argnums_partial, except the error
# messages are specific to jax.remat (and thus more actionable), the
# hashing/caching behavior is slightly different, and this function accepts a
# boolean for static_argnums. Perhaps the two could be de-duplicated.
def _remat_static_argnums(fun, static_argnums, args):
if type(static_argnums) is int:
static_argnums = (static_argnums,)
elif not (type(static_argnums) is tuple and
all(type(d) is int for d in static_argnums)):
raise TypeError("the `static_argnums` argument to `jax.checkpoint` / "
"`jax.remat` must be an int, tuple of ints or, bool, but "
f"got value {static_argnums}")
if not all(-len(args) <= d < len(args) for d in static_argnums):
raise ValueError("the `static_argnums` argument to `jax.checkpoint` / "
"`jax.remat` can only take integer values greater than or "
"equal to `-len(args)` and less than `len(args)`, but got "
f"{static_argnums}")
if not static_argnums:
return fun, args
nargs = len(args)
static_argnums_ = frozenset(d % len(args) for d in static_argnums)
dyn_args, static_args = [], []
for i, x in enumerate(args):
if i in static_argnums_: static_args.append(WrapHashably(x))
else: dyn_args.append(x)
new_fun = _dyn_args_fun(fun, static_argnums_, tuple(static_args), nargs)
return new_fun, dyn_args
class WrapHashably:
  """Wrap an arbitrary value so it can be used as part of a cache key.

  If ``hash(val)`` succeeds, the wrapper hashes and compares like ``val``.
  Otherwise the wrapper hashes by ``id(val)`` and compares by identity, so
  that ``__eq__`` stays consistent with ``__hash__`` and always returns a
  plain ``bool`` (value comparison on unhashable objects could return
  something else, or report distinct objects with different hashes as equal).
  """
  val: Any
  hash: Optional[int] = None
  hashable: bool

  def __init__(self, val):
    self.val = val
    try:
      self.hash = hash(val)
      self.hashable = True
    except Exception:
      # Unhashable value: fall back to an identity-based hash.
      self.hash = id(val)
      self.hashable = False

  def __hash__(self):
    return self.hash

  def __eq__(self, other):
    if not isinstance(other, WrapHashably):
      return False
    if self.hashable and other.hashable:
      try:
        return self.val == other.val
      except Exception:
        return self.val is other.val
    # At least one side hashes by id, so equality must be identity as well.
    return self.val is other.val
# Caching here avoids retracing even when static_argnums is used. See
# api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums: on
# that benchmark, including this caching makes a ~10x difference (which can be
# made arbitrarily large by involving larger functions to be traced).
@weakref_lru_cache
def _dyn_args_fun(fun: Callable, static_argnums: Tuple[int, ...],
                  static_args: Tuple[WrapHashably, ...], nargs: int):
  """Return a version of ``fun`` that takes only the dynamic arguments.

  The returned function rebuilds the full ``nargs``-long positional argument
  list by taking the unwrapped ``.val`` of the next entry of ``static_args``
  at every position contained in ``static_argnums`` and the next dynamic
  argument everywhere else, then calls ``fun`` with it and any keyword
  arguments.
  """
  def new_fun(*dyn_args, **kwargs):
    remaining_static = iter(static_args)
    remaining_dyn = iter(dyn_args)
    full_args = []
    for i in range(nargs):
      if i in static_argnums:
        full_args.append(next(remaining_static).val)
      else:
        full_args.append(next(remaining_dyn))
    return fun(*full_args, **kwargs)
  return new_fun
# This helper is similar to those in control_flow/common.py, but with
# remat-specific errors.
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
debug = pe.debug_info(fun, in_tree, False, "checkpoint")
debug = pe.debug_info(fun, in_tree, True, "checkpoint")
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
try:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
msg, = e.args
new_msg = msg + "\n\n" + (
"Consider using the `static_argnums` parameter for `jax.remat` or "
"`jax.checkpoint`. See the `jax.checkpoint` docstring and its example "
"involving `static_argnums`:\n"
"https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html"
"\n")
new_e = core.ConcretizationTypeError.__new__(core.ConcretizationTypeError)
new_e.args = (new_msg,)
raise new_e from None
return pe.convert_constvars_jaxpr(jaxpr), consts, out_tree()

View File

@ -53,7 +53,8 @@ from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, argnames_partial_except, validate_argnames, validate_argnums)
shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
validate_argnames, validate_argnums)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
@ -63,7 +64,7 @@ from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, new_name_stack, wrap_name, cache,
wraps, HashableFunction)
wraps, HashableFunction, weakref_lru_cache)
# Unused imports to be exported
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
@ -71,6 +72,7 @@ from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
from jax._src.ad_checkpoint import _remat_static_argnums
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
@ -660,12 +662,12 @@ def disable_jit():
"""Context manager that disables :py:func:`jit` behavior under its dynamic context.
For debugging it is useful to have a mechanism that disables :py:func:`jit`
everywhere in a dynamic context. Note that this not only disables explicit uses
of `jit` by the user, but will also remove any implicit JIT compilation used by the
JAX library: this includes implicit JIT computation of `body` and `cond`
functions passed to higher-level primitives like :func:`scan` and :func:`while_loop`,
JIT used in implementations of :mod:`jax.numpy` functions, and any other case where
`jit` is used within an API's implementation.
everywhere in a dynamic context. Note that this not only disables explicit
uses of `jit` by the user, but will also remove any implicit JIT compilation
used by the JAX library: this includes implicit JIT computation of `body` and
`cond` functions passed to higher-level primitives like :func:`scan` and
:func:`while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
and any other case where `jit` is used within an API's implementation.
Values that have a data dependence on the arguments to a jitted function are
traced and abstracted. For example, an abstract value may be a
@ -3046,130 +3048,81 @@ def eval_shape(fun: Callable, *args, **kwargs):
return tree_unflatten(out_tree(), out)
def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
def checkpoint(fun: Callable, *,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...], bool] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
if concrete:
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
"in its place, you can use its `static_argnums` option, and if "
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
"\n"
"For example, if using `concrete=True` for an `is_training` flag:\n"
"\n"
" from functools import partial\n"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, is_training):\n"
" if is_training:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"replace it with a use of `static_argnums`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, is_training):\n"
" ...\n"
"\n"
"If jax.numpy operations need to be performed on static arguments, "
"we can use the `jax.ensure_compile_time_eval()` context manager. "
"For example, we can replace this use of `concrete=True`\n:"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, y):\n"
" if y > 0:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"with this combination of `static_argnums` and "
"`jax.ensure_compile_time_eval()`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, y):\n"
" with jax.ensure_compile_time_eval():\n"
" y_pos = y > 0\n"
" if y_pos:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n")
if config.jax_new_checkpoint:
raise NotImplementedError(msg)
else:
warn(msg, DeprecationWarning)
The :func:`jax.checkpoint` decorator, aliased to ``jax.remat``, provides a
way to trade off computation time and memory cost in the context of automatic
differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
and :func:`jax.vjp` but also with :func:`jax.linearize`.
When differentiating a function in reverse-mode, by default all the
linearization points (e.g. inputs to elementwise nonlinear primitive
operations) are stored when evaluating the forward pass so that they can be
reused on the backward pass. This evaluation strategy can lead to a high
memory cost, or even to poor performance on hardware accelerators where memory
access is much more expensive than FLOPs.
An alternative evaluation strategy is for some of the linearization points to
be recomputed (i.e. rematerialized) rather than stored. This approach can
reduce memory usage at the cost of increased computation.
This function decorator produces a new version of ``fun`` which follows
the rematerialization strategy rather than the default store-everything
strategy. That is, it returns a new version of ``fun`` which, when
differentiated, doesn't store any of its intermediate linearization points.
Instead, these linearization points are recomputed from the function's saved
inputs.
See the examples below.
Args:
fun: Function for which the autodiff evaluation strategy is to be changed
from the default of storing all intermediate linearization points to
recomputing them. Its arguments and return value should be arrays,
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
concrete: Optional, boolean indicating whether ``fun`` may involve
value-dependent Python control flow (default False). Support for such
control flow is optional, and disabled by default, because in some
edge-case compositions with :func:`jax.jit` it can lead to some extra
computation.
prevent_cse: Optional, boolean indicating whether to prevent common
subexpression elimination (CSE) optimizations in the HLO generated from
differentiation. This CSE prevention has costs because it can foil other
optimizations, and because it can incur high overheads on some backends,
especially GPU. The default is True because otherwise, under a ``jit`` or
``pmap``, CSE can defeat the purpose of this decorator. But in some
settings, like when used inside a ``scan``, this CSE prevention mechanism
is unnecessary, in which case ``prevent_cse`` can be set to False.
policy: Optional callable, one of the attributes of
``jax.checkpoint_policies``, which takes as input a type-level
specification of a first-order primitive application and returns a boolean
indicating whether the corresponding output value(s) can be saved as a
residual (or instead must be recomputed in the (co)tangent computation if
needed).
Returns:
A function (callable) with the same input/output behavior as ``fun`` but
which, when differentiated using e.g. :func:`jax.grad`, :func:`jax.vjp`, or
:func:`jax.linearize`, recomputes rather than stores intermediate
linearization points, thus potentially saving memory at the cost of extra
computation.
Here is a simple example:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.checkpoint
... def g(x):
... y = jnp.sin(x)
... z = jnp.sin(y)
... return z
...
>>> jax.value_and_grad(g)(2.0)
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
Here, the same value is produced whether or not the :func:`jax.checkpoint`
decorator is present. When the decorator is not present, the values
``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
pass and are stored for use in the backward pass, because they are needed
on the backward pass and depend only on the primal inputs. When using
:func:`jax.checkpoint`, the forward pass will compute only the primal outputs
and only the primal inputs (``2.0``) will be stored for the backward pass.
At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
While ``jax.checkpoint`` controls what values are stored from the forward-pass
to be used on the backward pass, the total amount of memory required to
evaluate a function or its VJP depends on many additional internal details of
that function. Those details include which numerical primitives are used,
how they're composed, where jit and control flow primitives like scan
are used, and other factors.
The :func:`jax.checkpoint` decorator can be applied recursively to express
sophisticated autodiff rematerialization strategies. For example:
>>> def recursive_checkpoint(funs):
... if len(funs) == 1:
... return funs[0]
... elif len(funs) == 2:
... f1, f2 = funs
... return lambda x: f1(f2(x))
... else:
... f1 = recursive_checkpoint(funs[:len(funs)//2])
... f2 = recursive_checkpoint(funs[len(funs)//2:])
... return lambda x: f1(jax.checkpoint(f2)(x))
...
"""
if config.jax_new_checkpoint and not concrete:
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy)
if config.jax_new_checkpoint:
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
@wraps(fun)
@api_boundary
def remat_f(*args, **kwargs):
f, args = _remat_static_argnums(fun, static_argnums, args)
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = flatten_fun(lu.wrap_init(f), in_tree)
out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
concrete=concrete, prevent_cse=prevent_cse,
differentiated=False,
policy=policy)
differentiated=False, policy=policy)
return tree_unflatten(out_tree(), out_flat)
return remat_f
checkpoint.__doc__ = new_checkpoint.__doc__
remat = checkpoint # type: ignore
def named_call(
fun: Callable[..., Any],
*,

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from jax import core
@ -145,7 +146,7 @@ class ConcretizationTypeError(JAXTypeError):
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer", context: str = ""):
def __init__(self, tracer: core.Tracer, context: str = ""):
super().__init__(
"Abstract tracer value encountered where concrete value is expected: "
f"{tracer}\n{context}{tracer._origin_msg()}\n")
@ -237,7 +238,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
>>> manual_clip(jnp.arange(-2, 2))
DeviceArray([0, 0, 0, 1], dtype=int32)
"""
def __init__(self, tracer: "core.Tracer"):
def __init__(self, tracer: core.Tracer):
super().__init__(
f"Array boolean indices must be concrete; got {tracer}\n")
@ -316,7 +317,7 @@ class TracerArrayConversionError(JAXTypeError):
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
def __init__(self, tracer: core.Tracer):
super().__init__(
"The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {tracer}{tracer._origin_msg()}")
@ -409,7 +410,7 @@ class TracerIntegerConversionError(JAXTypeError):
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
def __init__(self, tracer: core.Tracer):
super().__init__(
f"The __index__() method was called on the JAX Tracer object {tracer}")

View File

@ -908,7 +908,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
local_devices = list(api.local_devices())
if len(local_devices) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
mesh_devices = np.array(local_devices[:size]).reshape(shape) # type: ignore
with Mesh(mesh_devices, axis_names):
yield

View File

@ -3911,8 +3911,11 @@ class RematTest(jtu.JaxTestCase):
expected = f_lin_expected(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
@unittest.skipIf(config.jax_new_checkpoint, "test is for old remat only")
def test_remat_grad_python_control_flow(self):
@partial(api.remat, concrete=True)
# See test `test_remat_grad_python_control_flow_static_argnums` for the
# new recommended way to express this computation.
def g(x):
if x > 0:
return lax.sin(x), 3.
@ -3923,6 +3926,75 @@ class RematTest(jtu.JaxTestCase):
x, _ = g(x)
return x
with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
g = api.remat(g, concrete=True)
ans = f(2.)
expected = np.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f)(2.)
expected = np.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
@unittest.skipIf(config.jax_new_checkpoint, "new remat raises error here")
def test_remat_concrete_deprecation_warning(self):
  # Requesting `concrete=True` should emit a DeprecationWarning whose message
  # mentions `static_argnums`.
  def g(x):
    if x > 0:
      return lax.sin(x), 3.
    return lax.cos(x), 4.

  with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
    api.remat(g, concrete=True)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat warns here")
def test_remat_concrete_deprecation_error(self):
  # Requesting `concrete=True` should raise NotImplementedError with a message
  # that mentions `static_argnums`.
  def g(x):
    if x > 0:
      return lax.sin(x), 3.
    return lax.cos(x), 4.

  with self.assertRaisesRegex(NotImplementedError, "static_argnums"):
    api.remat(g, concrete=True)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat different error")
def test_remat_concrete_error(self):
  # Case 1: neither static_argnums nor concrete is given, so branching on the
  # argument value should raise an error that suggests `static_argnums`.
  @api.remat
  def g(x):
    return lax.sin(x) if x > 0 else lax.cos(x)

  with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
    g(3.)

  # Case 2: static_argnums is used, but jnp operations still get staged, so
  # the comparison on an array argument is expected to fail the same way.
  @partial(api.remat, static_argnums=(0,))
  def g(x):
    return lax.sin(x) if x > 0 else lax.cos(x)

  with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
    g(jnp.array(3.))
def test_remat_grad_python_control_flow_static_argnums(self):
@partial(api.remat, static_argnums=(0,))
def g(x):
with jax.ensure_compile_time_eval():
x_pos = x > 0
if x_pos:
return lax.sin(x), 3.
else:
return lax.cos(x), 4.
def f(x):
x, _ = g(x)
return x
ans = f(2.)
expected = np.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -3931,6 +4003,89 @@ class RematTest(jtu.JaxTestCase):
expected = np.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_grad_python_control_flow_unhashable_static_argnums(self):
  # A static argument whose `__hash__` raises should still work, both for the
  # primal computation and under `grad`.
  class A:
    def __init__(self, val):
      self.val = val
    def __hash__(self):
      raise TypeError

  @partial(api.remat, static_argnums=(0,))
  def g(x):
    x = x.val
    with jax.ensure_compile_time_eval():
      x_pos = x > 0
    if x_pos:
      return lax.sin(x), 3.
    return lax.cos(x), 4.

  def f(x):
    return g(x)[0]

  self.assertAllClose(f(A(2.)), np.sin(2.), check_dtypes=False)

  grad_ans = api.grad(lambda x: f(A(x)))(2.)
  self.assertAllClose(grad_ans, np.cos(2.), check_dtypes=False)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
def test_remat_retracing(self):
  # This is *not* a very important behavior; remat doesn't need to provide
  # caching guarantees with the same importance as jit. But even so, in the
  # interest of not redoing tracing work (and thus making jax.remat more
  # feasible to use in eager mode), this test checks that the remat-decorated
  # function body runs only once across repeated calls.
  num_traces = 0

  @api.remat
  def g(x):
    nonlocal num_traces
    num_traces += 1
    return lax.sin(x), 3.

  def f(x):
    return g(x)[0]

  for _ in range(10):
    f(2.).block_until_ready()

  self.assertEqual(num_traces, 1)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
def test_remat_static_agnums_retracing(self):
  # This is *not* a super important behavior; remat doesn't need to provide
  # caching guarantees with the same importance as jit. But even so, in the
  # interest of not redoing tracing work (and thus making jax.remat more
  # feasible to use in eager mode), this test checks that the remat-decorated
  # function body runs only once across repeated calls *even with
  # static_argnums*. See also the test without static_argnums.
  num_traces = 0

  @partial(api.remat, static_argnums=(0,))
  def g(x):
    nonlocal num_traces
    num_traces += 1
    with jax.ensure_compile_time_eval():
      x_pos = x > 0
    if x_pos:
      return lax.sin(x), 3.
    return lax.cos(x), 4.

  def f(x):
    return g(x)[0]

  for _ in range(10):
    f(2.).block_until_ready()

  self.assertEqual(num_traces, 1)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
@ -4784,6 +4939,15 @@ class RematTest(jtu.JaxTestCase):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
def test_vjp_caching_static_argnums(self):
  # Repeatedly applying the VJP of a remat function with a static argument
  # should be counted as a single compile.
  def maybe_double(x, y):
    return jax.jit(lambda x: 2 * x if y else x)(x)

  identity = jax.remat(maybe_double, static_argnums=(1,))
  _, f_vjp = jax.vjp(identity, 1., True)
  with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
    for _ in range(20):
      cotangent, = f_vjp(1.)[:1]
      cotangent.block_until_ready()
  self.assertEqual(count[0], 1)  # fwd execute_trivial, backward_pass on bwd
@unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
def test_fwd_caching(self):
# see above test also
@ -4794,6 +4958,16 @@ class RematTest(jtu.JaxTestCase):
y.block_until_ready()
self.assertEqual(count[0], 1)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
def test_fwd_caching_static_argnums(self):
  # Repeated forward calls of a checkpointed jitted function whose only
  # argument is static should be counted as a single compile.
  doubler = jax.jit(lambda x: 2 * x)
  identity = jax.checkpoint(doubler, static_argnums=(0,))
  with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
    for _ in range(20):
      identity(1.).block_until_ready()
  self.assertEqual(count[0], 1)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [