2021-10-12 20:06:38 -07:00
|
|
|
# Copyright 2021 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from functools import partial
|
2021-10-13 18:21:20 -07:00
|
|
|
from typing import Callable, Optional, List, Tuple
|
2021-10-12 20:06:38 -07:00
|
|
|
import types
|
|
|
|
|
|
|
|
import jax
|
|
|
|
from jax import core
|
|
|
|
from jax import linear_util as lu
|
|
|
|
from jax.interpreters import ad
|
|
|
|
from jax.interpreters import batching
|
|
|
|
from jax.interpreters import partial_eval as pe
|
|
|
|
from jax.interpreters import xla
|
|
|
|
from jax.tree_util import tree_flatten, tree_unflatten
|
|
|
|
from jax._src import ad_util
|
|
|
|
from jax._src import source_info_util
|
|
|
|
from jax._src.api_util import flatten_fun
|
|
|
|
from jax._src.traceback_util import api_boundary
|
|
|
|
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
|
|
|
|
safe_zip)
|
|
|
|
|
2021-10-13 18:21:20 -07:00
|
|
|
# Exclude this file from source-info summaries, presumably so locations
# reported by source_info_util.summarize (e.g. in saved_residuals) point at
# user code rather than at this module.
source_info_util.register_exclusion(__file__)

# TODO(mattjj): before this can be the standard remat implementation, we must:
#   [ ] fix up callers who use the 'concrete' option (now removed)
#   [ ] implement remat-of-control-flow-primitives (passing through the policy)

# Shadow the builtins with jax._src.util's safe_map/safe_zip for the rest of
# this module (presumably length-checked variants -- see jax._src.util).
map, zip = safe_map, safe_zip
|
|
|
|
|
|
|
|
|
2021-10-13 18:21:20 -07:00
|
|
|
### Policies
|
|
|
|
|
|
|
|
def everything_saveable(*_args, **_params) -> bool:
  """Remat policy that allows every primitive output to be saved.

  This is the effective policy when ``jax.remat`` is not used at all.
  """
  saveable = True
  return saveable
|
|
|
|
|
|
|
|
def nothing_saveable(*_args, **_params) -> bool:
  """Remat policy that forbids saving any primitive output.

  This is the effective policy when ``jax.remat`` is used without an explicit
  policy: every linearization point is recomputed.
  """
  saveable = False
  return saveable
|
|
|
|
|
|
|
|
def checkpoint_dots(prim, *_args, **_params) -> bool:
  """Remat policy that saves only matrix-multiply-like primitive outputs.

  Matrix multiplies are expensive, so their results (``dot_general`` and
  ``conv_general_dilated``) are saved and everything else is recomputed.
  """
  # Looked up at call time rather than import time, since jax._src.lax is
  # reached through attribute access on the jax package here.
  expensive_prims = {jax._src.lax.lax.dot_general_p,
                     jax._src.lax.convolution.conv_general_dilated_p}
  return prim in expensive_prims
|
2021-10-12 20:06:38 -07:00
|
|
|
|
2021-10-13 18:21:20 -07:00
|
|
|
def dot_with_no_batch_dims(prim, *_args, **params) -> bool:
  """Remat policy saving ``dot_general`` outputs that have no batch dimensions.

  This is a useful heuristic for transformers. Every other primitive, and any
  ``dot_general`` with lhs or rhs batch dimensions, is not saveable.
  """
  if prim is not jax._src.lax.lax.dot_general_p:
    return False
  # dimension_numbers is ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)).
  (_, _), (lhs_batch, rhs_batch) = params['dimension_numbers']
  return not (lhs_batch or rhs_batch)
|
|
|
|
|
2021-10-13 18:21:20 -07:00
|
|
|
# Identity primitive that tags a value with a ``name`` param (bound by
# checkpoint_name). The save_*_names policies match on it, and saved_residuals
# reports residuals produced by it as "named '<name>'".
name_p = core.Primitive('name')
|
|
|
|
|
|
|
|
def save_any_names_but_these(*names_not_to_save):
  """Build a remat policy that saves named values except the given names.

  Only values tagged via ``checkpoint_name`` (i.e. outputs of ``name_p``) are
  ever saveable; among those, any whose name is in ``names_not_to_save`` is
  recomputed instead.

  Args:
    *names_not_to_save: names (as passed to ``checkpoint_name``) to exclude.

  Returns:
    A policy callable ``policy(prim, *avals, **params) -> bool``.
  """
  excluded = frozenset(names_not_to_save)

  def policy(prim, *_, **params):
    # Unnamed values are never saved under this policy.
    return prim is name_p and params['name'] not in excluded
  return policy
|
|
|
|
|
|
|
|
def save_only_these_names(*names_which_can_be_saved):
  """Build a remat policy that saves only values carrying one of the given names.

  Only outputs of ``name_p`` (values tagged via ``checkpoint_name``) whose name
  is in the allow-list are saveable; everything else is recomputed.

  Args:
    *names_which_can_be_saved: names (as passed to ``checkpoint_name``) to save.

  Returns:
    A policy callable ``policy(prim, *avals, **params) -> bool``.
  """
  allowed = frozenset(names_which_can_be_saved)

  def policy(prim, *_, **params):
    if prim is not name_p:
      return False  # not saveable unless it's a named value on the allow-list
    return params['name'] in allowed
  return policy
|
|
|
|
|
2021-10-12 20:06:38 -07:00
|
|
|
# Namespace of the built-in remat policies; the ``checkpoint`` docstring refers
# to these as the attributes of ``jax.checkpoint_policies``. Note that
# ``dot_with_no_batch_dims`` is exported under the name
# ``checkpoint_dots_with_no_batch_dims``.
checkpoint_policies = types.SimpleNamespace(
    everything_saveable=everything_saveable,
    nothing_saveable=nothing_saveable,
    checkpoint_dots=checkpoint_dots,
    checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims,
    save_any_names_but_these=save_any_names_but_these,
    save_only_these_names=save_only_these_names,
)
|
|
|
|
|
2021-10-13 18:21:20 -07:00
|
|
|
|
|
|
|
### Main API
|
|
|
|
|
2021-10-12 20:06:38 -07:00
|
|
|
def checkpoint(fun: Callable, prevent_cse: bool = True,
               policy: Optional[Callable[..., bool]] = None
               ) -> Callable:
  """Make ``fun`` recompute internal linearization points when differentiated.

  The :func:`jax.checkpoint` decorator, aliased to ``jax.remat``, provides a
  way to trade off computation time and memory cost in the context of automatic
  differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
  and :func:`jax.vjp` but also with :func:`jax.linearize`.

  When differentiating a function in reverse-mode, by default all the
  linearization points (e.g. inputs to elementwise nonlinear primitive
  operations) are stored when evaluating the forward pass so that they can be
  reused on the backward pass. This evaluation strategy can lead to a high
  memory cost, or even to poor performance on hardware accelerators where memory
  access is much more expensive than FLOPs.

  An alternative evaluation strategy is for some of the linearization points to
  be recomputed (i.e. rematerialized) rather than stored. This approach can
  reduce memory usage at the cost of increased computation.

  This function decorator produces a new version of ``fun`` which follows
  the rematerialization strategy rather than the default store-everything
  strategy. That is, it returns a new version of ``fun`` which, when
  differentiated, doesn't store any of its intermediate linearization points.
  Instead, these linearization points are recomputed from the function's saved
  inputs.

  ``fun`` is traced with abstract (shaped) values of its arguments, so Python
  control flow that depends on the *values* of its arguments is not supported.
  The former ``concrete`` option, which enabled such control flow, has been
  removed.

  See the examples below.

  Args:
    fun: Function for which the autodiff evaluation strategy is to be changed
      from the default of storing all intermediate linearization points to
      recomputing them. Its arguments and return value should be arrays,
      scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
    prevent_cse: Optional, boolean indicating whether to prevent common
      subexpression elimination (CSE) optimizations in the HLO generated from
      differentiation. This CSE prevention has costs because it can foil other
      optimizations, and because it can incur high overheads on some backends,
      especially GPU. The default is True because otherwise, under a ``jit`` or
      ``pmap``, CSE can defeat the purpose of this decorator. But in some
      settings, like when used inside a ``scan``, this CSE prevention mechanism
      is unnecessary, in which case ``prevent_cse`` can be set to False. It only
      affects the rematerialized (differentiated) computation.
    policy: This is an experimental feature and the API is likely to change.
      Optional callable, one of the attributes of ``jax.checkpoint_policies``,
      which takes as input a type-level specification of a first-order primitive
      application and returns a boolean indicating whether the corresponding
      output value(s) can be saved as a residual (or, if not, instead must be
      recomputed in the (co)tangent computation). If None (the default),
      nothing is saved, i.e. the behavior of
      ``jax.checkpoint_policies.nothing_saveable``.

  Returns:
    A function (callable) with the same input/output behavior as ``fun`` but
    which, when differentiated using e.g. :func:`jax.grad`, :func:`jax.vjp`, or
    :func:`jax.linearize`, recomputes rather than stores intermediate
    linearization points, thus potentially saving memory at the cost of extra
    computation.

  Here is a simple example:

  >>> import jax
  >>> import jax.numpy as jnp

  >>> @jax.checkpoint
  ... def g(x):
  ...   y = jnp.sin(x)
  ...   z = jnp.sin(y)
  ...   return z
  ...
  >>> jax.value_and_grad(g)(2.0)
  (DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))

  Here, the same value is produced whether or not the :func:`jax.checkpoint`
  decorator is present. When the decorator is not present, the values
  ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
  pass and are stored for use in the backward pass, because they are needed
  on the backward pass and depend only on the primal inputs. When using
  :func:`jax.checkpoint`, the forward pass will compute only the primal outputs
  and only the primal inputs (``2.0``) will be stored for the backward pass.
  At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
  ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.

  While ``jax.checkpoint`` controls what values are stored from the forward-pass
  to be used on the backward pass, the total amount of memory required to
  evaluate a function or its VJP depends on many additional internal details of
  that function. Those details include which numerical primitives are used,
  how they're composed, where jit and control flow primitives like scan
  are used, and other factors.

  The :func:`jax.checkpoint` decorator can be applied recursively to express
  sophisticated autodiff rematerialization strategies. For example:

  >>> def recursive_checkpoint(funs):
  ...   if len(funs) == 1:
  ...     return funs[0]
  ...   elif len(funs) == 2:
  ...     f1, f2 = funs
  ...     return lambda x: f1(f2(x))
  ...   else:
  ...     f1 = recursive_checkpoint(funs[:len(funs)//2])
  ...     f2 = recursive_checkpoint(funs[len(funs)//2:])
  ...     return lambda x: f1(jax.checkpoint(f2)(x))
  ...
  """
  @wraps(fun)
  @api_boundary
  def fun_remat(*args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    # Trace abstractly: only shapes/dtypes of the arguments are used.
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    debug = pe.debug_info(fun, in_tree, False, "checkpoint")
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
    # Closed-over constants become leading explicit arguments, so the jaxpr
    # carried by remat_p has no constvars (the remat_p rules assert this).
    # differentiated=False: no AD has been applied to this call yet; the
    # partial-eval rules set it to True on the staged (recomputed) half.
    out_flat = remat_p.bind(
        *consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr),
        prevent_cse=prevent_cse, differentiated=False, policy=policy)
    return tree_unflatten(out_tree(), out_flat)
  return fun_remat


remat = checkpoint  # alias
|
|
|
|
|
2021-10-13 18:21:20 -07:00
|
|
|
|
|
|
|
### Utilities
|
|
|
|
|
|
|
|
def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
  """Describe the residuals ``jax.linearize`` saves for ``f`` at these arguments.

  The function is linearized at ``(args, kwargs)`` inside ``jax.make_jaxpr``.
  The outputs of the resulting jaxpr are the values the linearized function
  closes over, i.e. the residuals.

  Returns:
    A list of ``(aval, description)`` pairs, one per residual, in this order:
    literals, then constants, then residuals that are inputs of ``f``, then
    residuals produced by equations. Outputs of ``name_p`` are described by
    their checkpoint name.
  """
  flat_args, in_tree = tree_flatten((args, kwargs))

  def flat_f(*leaves):
    call_args, call_kwargs = tree_unflatten(in_tree, leaves)
    return f(*call_args, **call_kwargs)

  jaxpr = jax.make_jaxpr(lambda *leaves: jax.linearize(flat_f, *leaves)[1])(
      *flat_args).jaxpr
  outvars = jaxpr.outvars
  literal_outs = [x for x in outvars if isinstance(x, core.Literal)]
  var_outs = {x for x in outvars if not isinstance(x, core.Literal)}

  results = [(lit.aval, 'from a literal') for lit in literal_outs]
  results.extend((v.aval, 'from a constant')
                 for v in jaxpr.constvars if v in var_outs)

  assert len(jaxpr.invars) == len(flat_args)
  results.extend(
      (v.aval, f'from {pe.arg_info_pytree(f, in_tree, True, [i])}')
      for i, v in enumerate(jaxpr.invars) if v in var_outs)

  for eqn in jaxpr.eqns:
    src = source_info_util.summarize(eqn.source_info)
    for v in eqn.outvars:
      if v not in var_outs:
        continue
      if eqn.primitive is name_p:
        description = f"named '{eqn.params['name']}' from {src}"
      else:
        description = f'from {src}'
      results.append((v.aval, description))

  assert len(results) == len(jaxpr.outvars)
  return results
|
2021-10-13 18:21:20 -07:00
|
|
|
|
|
|
|
def print_saved_residuals(f, *args, **kwargs):
  """Print one line per residual reported by ``saved_residuals``.

  Each line is the residual's short aval string followed by its description.
  """
  residuals = saved_residuals(f, *args, **kwargs)
  for res_aval, description in residuals:
    aval_str = res_aval.str_short(short_dtypes=True)
    print(f'{aval_str} {description}')
|
|
|
|
|
|
|
|
|
|
|
|
### Implementation
|
|
|
|
|
2021-10-12 20:06:38 -07:00
|
|
|
# The remat primitive (registered under the name 'remat2'). It returns a list
# of outputs, and is bound with params jaxpr (no constvars), prevent_cse,
# differentiated and policy.
remat_p = core.Primitive('remat2')
remat_p.multiple_results = True
|
|
|
|
|
|
|
|
@remat_p.def_impl
def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy):
  """Eager evaluation of ``remat_p``: just evaluate the wrapped jaxpr."""
  # These params only matter for lowering and AD, not for direct evaluation.
  del prevent_cse, differentiated, policy
  no_consts = ()
  return core.eval_jaxpr(jaxpr, no_consts, *args)
|
|
|
|
|
|
|
|
@remat_p.def_abstract_eval
def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
  """Abstract evaluation of ``remat_p``: the avals of the jaxpr's outputs."""
  del args, prevent_cse, differentiated, policy  # Unused.
  out_avals = [outvar.aval for outvar in jaxpr.outvars]
  return out_avals
|
|
|
|
|
Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.
Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.
In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.
PiperOrigin-RevId: 403607667
2021-10-16 07:52:57 -07:00
|
|
|
def remat_translation(ctx, avals_in, avals_out, *in_nodes,
                      jaxpr, prevent_cse, differentiated, policy):
  """XLA translation rule for ``remat_p``.

  Differentiated calls with ``prevent_cse`` set are lowered via
  ``xla._remat_using_while`` on GPU and ``xla._remat_using_cond`` elsewhere.
  Every other call is inlined with ``xla.jaxpr_subcomp``.
  """
  del policy  # Unused.
  if not (differentiated and prevent_cse):
    return xla.jaxpr_subcomp(ctx, jaxpr, (), *in_nodes)
  # Presumably the control-flow wrapper hides the recomputation from XLA's CSE
  # (see the ``prevent_cse`` docs on ``checkpoint``).
  lower_with_control_flow = (xla._remat_using_while if ctx.platform == "gpu"
                             else xla._remat_using_cond)
  return lower_with_control_flow(ctx, in_nodes, "checkpoint", jaxpr)


xla.register_translation(remat_p, remat_translation)
|
2021-10-12 20:06:38 -07:00
|
|
|
|
|
|
|
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  """JVP rule for ``remat_p``.

  Builds the JVP of ``jaxpr`` with symbolic-zero input tangents pruned. It then
  binds that JVP jaxpr under a new ``remat_p`` with the same params, so the
  tangent computation stays rematerialized. Tangent outputs that the JVP jaxpr
  reports as zero are returned as ``ad_util.Zero`` values.
  """
  assert not jaxpr.constvars
  in_nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  jvp_closed, out_nonzeros = ad.jvp_jaxpr(
      core.ClosedJaxpr(jaxpr, ()), in_nonzeros, False)
  # Symbolic zeros are not passed to the call; their positions are already
  # encoded in the JVP jaxpr via in_nonzeros.
  live_tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  outs = remat_p.bind(
      *jvp_closed.consts, *primals, *live_tangents,
      jaxpr=pe.convert_constvars_jaxpr(jvp_closed.jaxpr),
      prevent_cse=prevent_cse, differentiated=differentiated, policy=policy)
  num_primal_outs = len(jaxpr.outvars)
  out_primals, live_out_tangents = split_list(outs, [num_primal_outs])
  remaining = iter(live_out_tangents)
  out_tangents = []
  for primal, nonzero in zip(out_primals, out_nonzeros):
    out_tangents.append(
        next(remaining) if nonzero else ad_util.Zero.from_value(primal))
  return out_primals, out_tangents
ad.primitive_jvps[remat_p] = remat_jvp
|
|
|
|
|
|
|
|
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for ``remat_p``.

  Splits ``jaxpr`` by which input tracers are known, using the remat policy to
  decide which intermediates may be saved as residuals. The known half is
  evaluated immediately, outside of any remat, producing the known outputs and
  the residuals. The unknown half is staged as a new ``remat_p`` equation with
  ``differentiated=True``, taking the residuals plus the (used) inputs.

  Returns:
    Output tracers in the original output order: known outputs as known
    JaxprTracers, unknown outputs as tracers whose recipe is the staged remat.
  """
  assert not jaxpr.constvars
  # policy=None means nothing is saveable (same as nothing_saveable).
  policy = params['policy'] or (lambda *_, **__: False)
  # unzip into jaxpr_known and jaxpr_unknown
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  # DCE both halves; in_used_* record which inputs survive, so the argument
  # lists built below are filtered to match.
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
  # NOTE(review): assumes jaxpr_unknown's outputs are the out_inst positions,
  # of which only the truly unknown ones are kept -- confirm against
  # pe._partial_eval_jaxpr_custom.
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)

  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  out_consts_ = iter(out_consts)
  # form known outputs and collect residual tracers
  # jaxpr_known's outputs are the known outputs first, then the residuals:
  # the comprehension consumes one value per known output, and whatever the
  # iterator still holds afterwards are the residuals.
  out_known_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
      for uk in out_unknowns if not uk]
  residuals = list(out_consts_)

  # set up unknown outputs with a recipe to call remat
  # jaxpr_unknown's inputs are [*residuals, *all original inputs], filtered by
  # DCE via in_used_unknown.
  res_tracers = map(trace.new_instantiated_const, residuals)
  in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
  _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
  out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
  # The staged half is the rematerialized computation, so mark it
  # differentiated (which is what enables prevent_cse in remat_translation).
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  for t in out_jaxpr_tracers: t.recipe = recipe

  # zip together known and unknown outputs
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
pe.custom_partial_eval_rules[remat_p] = remat_partial_eval
|
|
|
|
|
|
|
|
def remat_partial_eval_custom_params_updater(_, __, params_known, params_staged):
  """Fix up the params of the two ``remat_p`` equations from a custom split.

  ``pe.call_partial_eval_custom_rule`` hands over the known and staged param
  dicts with the split jaxprs under the ``'call_jaxpr'`` key. ``remat_p``
  expects its jaxpr under ``'jaxpr'`` instead, so that key is renamed. The
  staged half is also marked ``differentiated=True``.

  Args:
    _, __: leading arguments supplied by the caller; unused here.
    params_known: params for the known half; must contain ``'call_jaxpr'``.
    params_staged: params for the staged half; must contain ``'call_jaxpr'``.

  Returns:
    ``(new_params_known, new_params_staged)``. These are fresh dicts; the input
    dicts are left unmodified.

  Raises:
    KeyError: if either dict lacks ``'call_jaxpr'``.
  """
  def renamed(params, **overrides):
    # Copy without 'call_jaxpr' instead of popping it from the caller's dict.
    new_params = {k: v for k, v in params.items() if k != 'call_jaxpr'}
    new_params['jaxpr'] = params['call_jaxpr']
    new_params.update(overrides)
    return new_params

  new_known = renamed(params_known)
  new_staged = renamed(params_staged, differentiated=True)
  return new_known, new_staged


# Backward-compatible alias for the original (misspelled) name.
remat_paratial_eval_custom_params_updater = remat_partial_eval_custom_params_updater
|
|
|
|
# Custom partial-eval rule for remat_p equations found inside a jaxpr: the
# generic call rule performs the split ('jaxpr' is passed first, presumably the
# name of the param holding the callee jaxpr -- confirm against pe), and
# remat_paratial_eval_custom_params_updater renames the resulting params.
pe.partial_eval_jaxpr_custom_rules[remat_p] = \
    partial(pe.call_partial_eval_custom_rule, 'jaxpr',
            remat_paratial_eval_custom_params_updater)
|
|
|
|
|
|
|
|
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  """Transpose rule for ``remat_p``.

  The whole transposition is traced into a new jaxpr and bound under another
  ``remat_p`` with the same params. That traced jaxpr re-runs ``jaxpr`` on the
  known primals to recover the linear (tangent) part, then runs the backward
  pass. As a result, the cotangent computation is rematerialized too.

  Returns:
    Input cotangents, unflattened with the tree structure produced by
    ``ad.backward_pass``.
  """
  assert not jaxpr.constvars
  # Mutable holder used to get the cotangent treedef out of `transposed`; it
  # is only known once `transposed` has been traced below.
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    # `treedef` is assigned in the enclosing scope before this is traced.
    in_primals, out_cts = tree_unflatten(treedef, args)
    # Undefined primals (the linear inputs) are unknown; the rest are known.
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    # Partial evaluation recomputes the known (linearization-point) part now
    # and leaves the tangent part as tangent_jaxpr.
    tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
    in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, consts, dummy_args,
                               out_cts)
    in_cts, cell.treedef = tree_flatten(in_cts_)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  # Hoist constants to leading arguments so the remat_p jaxpr has no constvars.
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
ad.reducing_transposes[remat_p] = remat_transpose
|
|
|
|
|
|
|
|
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  """Batching rule for ``remat_p``.

  Batches the wrapped jaxpr (without forcing unbatched outputs to be batched)
  and binds it under a new ``remat_p`` with the same params.

  Returns:
    ``(outputs, out_dims)``, where each batched output has its batch dimension
    at 0 and unbatched outputs have ``None``.
  """
  assert not jaxpr.constvars
  in_batched = [d is not batching.not_mapped for d in dims]
  batched_closed, out_batched = batching.batch_jaxpr(
      core.ClosedJaxpr(jaxpr, ()), axis_size, in_batched, instantiate=False,
      axis_name=axis_name, main_type=main_type)
  out_dims = [0 if is_batched else None for is_batched in out_batched]
  outs = remat_p.bind(*batched_closed.consts, *args,
                      jaxpr=batched_closed.jaxpr, **params)
  return outs, out_dims
batching.axis_primitive_batchers[remat_p] = remat_vmap
|
2021-10-13 18:21:20 -07:00
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_name(x, name):
  """Tag ``x`` with ``name`` so name-based remat policies can refer to it.

  The value itself is returned unchanged (``name_p`` is an identity).
  """
  tagged = name_p.bind(x, name=name)
  return tagged
|
|
|
|
|
|
|
|
# name_p is an identity: evaluation returns its operand, and abstract
# evaluation returns the operand's aval. The name is carried only as a param.
name_p.def_impl(lambda x, *, name: x)
name_p.def_abstract_eval(lambda x, *, name: x)
|
|
|
|
|
|
|
|
def name_jvp(primals, tangents, *, name):
  """JVP rule for ``name_p``: re-tag the primal, pass the tangent through."""
  x, = primals
  xdot, = tangents
  named_x = name_p.bind(x, name=name)
  return named_x, xdot  # don't name the tangent value
ad.primitive_jvps[name_p] = name_jvp
|
|
|
|
|
2021-10-19 09:47:55 -07:00
|
|
|
# XLA lowering for name_p: no ops are emitted; the operand is forwarded as the
# single result.
xla.register_translation(name_p,
                         lambda ctx, avals_in, avals_out, x, *, name: [x])
|
2021-10-13 18:21:20 -07:00
|
|
|
|
|
|
|
def name_batcher(args, dims, *, name):
  """Batching rule for ``name_p``: re-tag the batched operand, keep its dim."""
  operand, = args
  batch_dim, = dims
  return name_p.bind(operand, name=name), batch_dim
batching.primitive_batchers[name_p] = name_batcher
|