Merge pull request #11675 from mattjj:new-remat-caching

PiperOrigin-RevId: 464543287
This commit is contained in:
jax authors 2022-08-01 08:38:28 -07:00
commit be8939771c
10 changed files with 130 additions and 73 deletions

View File

@ -14,7 +14,7 @@
from functools import partial
import operator as op
from typing import Callable, Optional, List, Tuple
from typing import Callable, Optional, List, Tuple, Sequence, Union
import types
import jax
@ -27,16 +27,15 @@ from jax.interpreters import mlir
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import source_info_util
from jax._src.api_util import flatten_fun
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists)
safe_zip, merge_lists, weakref_lru_cache)
source_info_util.register_exclusion(__file__)
# TODO(mattjj): before this can be the standard remat implementation, we must:
# [ ] fix up callers who use the 'concrete' option (now removed)
# [ ] implement remat-of-control-flow-primitives (passing through the policy)
map = safe_map
zip = safe_zip
@ -209,18 +208,23 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
@api_boundary
def fun_remat(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(fun, in_tree, False, "checkpoint")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
in_avals = [shaped_abstractify(x) for x in args_flat]
jaxpr, consts, out_tree = _trace_to_jaxpr(fun, in_tree, tuple(in_avals))
out_flat = remat_p.bind(
*consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr),
prevent_cse=prevent_cse, differentiated=False, policy=policy)
return tree_unflatten(out_tree(), out_flat)
*consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse,
differentiated=False, policy=policy)
return tree_unflatten(out_tree, out_flat)
return fun_remat
remat = checkpoint # alias
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
  """Trace ``fun`` to a jaxpr whose constants are passed as leading inputs.

  Args:
    fun: the Python callable being checkpointed.
    in_tree: tree structure of the ``(args, kwargs)`` pair that ``fun``
      receives.
    in_avals: abstract values for the flattened inputs (the caller passes a
      tuple).

  Returns:
    A triple ``(jaxpr, consts, out_tree)``: the traced jaxpr after
    ``pe.convert_constvars_jaxpr``, the constants produced by tracing, and
    the output tree structure recorded while tracing.
  """
  dbg_info = pe.debug_info(fun, in_tree, False, "checkpoint")
  flat_callable, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)
  traced, _, const_vals = pe.trace_to_jaxpr_dynamic(
      flat_callable, in_avals, dbg_info)
  converted = pe.convert_constvars_jaxpr(traced)
  # The output tree thunk is only called once tracing has run.
  return converted, const_vals, out_tree_thunk()
### Utilities
@ -285,8 +289,7 @@ def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
assert not jaxpr.constvars
in_nonzeros = [type(t) is not ad_util.Zero for t in tangents]
jaxpr_ = core.ClosedJaxpr(jaxpr, ())
jaxpr_jvp_, out_nonzeros = ad.jvp_jaxpr(jaxpr_, in_nonzeros, False)
jaxpr_jvp_, out_nz = ad.jvp_jaxpr(pe.close_jaxpr(jaxpr), in_nonzeros, False)
nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero]
jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr)
outs = remat_p.bind(
@ -295,7 +298,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)])
out_tangents_ = iter(out_tangents_)
out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p)
for p, nz in zip(out_primals, out_nonzeros)]
for p, nz in zip(out_primals, out_nz)]
return out_primals, out_tangents
ad.primitive_jvps[remat_p] = remat_jvp
@ -351,45 +354,82 @@ pe.partial_eval_jaxpr_custom_rules[remat_p] = \
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  """Transpose rule for ``remat_p``.

  Obtains a transposed jaxpr from ``transpose_jaxpr``, binds it under
  ``remat_p`` again, and expands the bound outputs into one entry per element
  of ``in_primals``.

  Args:
    reduce_axes: forwarded unchanged to ``transpose_jaxpr``.
    out_cts: output cotangents; entries of type ``ad_util.Zero`` are treated
      as symbolic zeros.
    *in_primals: primal inputs; those for which ``ad.is_undefined_primal`` is
      true are the linear inputs.
    jaxpr: the jaxpr parameter of the equation; must have no constvars.
    **params: remaining ``remat_p`` parameters, forwarded to the new bind.

  Returns:
    A list aligned with ``in_primals``. Each entry is ``None`` for an input
    that is not an undefined primal, an ``ad_util.Zero`` where the transposed
    jaxpr reported a zero cotangent, and otherwise the next output of the
    bind.
  """
  assert not jaxpr.constvars
  linear_mask = [ad.is_undefined_primal(p) for p in in_primals]
  zero_ct_mask = [type(ct) is ad_util.Zero for ct in out_cts]
  closed_t, zero_flags = transpose_jaxpr(
      pe.close_jaxpr(jaxpr), linear_mask, zero_ct_mask, reduce_axes)
  inner_jaxpr, consts = closed_t.jaxpr, closed_t.consts
  bound_jaxpr = pe.convert_constvars_jaxpr(inner_jaxpr)
  flat_args, _ = tree_flatten((in_primals, out_cts))
  nonzero_cts = remat_p.bind(*consts, *flat_args, jaxpr=bound_jaxpr, **params)

  # Walk the inputs, drawing one zero flag per linear input and one bound
  # output per linear input whose flag is false.
  nonzero_iter, zero_iter = iter(nonzero_cts), iter(zero_flags)
  in_cts = []
  for p in in_primals:
    if not ad.is_undefined_primal(p):
      in_cts.append(None)
    elif next(zero_iter):
      in_cts.append(ad_util.Zero(p.aval))
    else:
      in_cts.append(next(nonzero_iter))
  # Both iterators must be fully consumed.
  assert next(nonzero_iter, None) is next(zero_iter, None) is None
  return in_cts
ad.reducing_transposes[remat_p] = remat_transpose
# TODO(mattjj): move this to ad.py
def transpose_jaxpr(jaxpr: core.ClosedJaxpr,
                    in_linear: Union[bool, Sequence[bool]],
                    out_zeros: Union[bool, Sequence[bool]],
                    reduce_axes: Sequence[core.AxisName],
                    ) -> Tuple[core.ClosedJaxpr, List[bool]]:
  """Normalize the arguments and delegate to ``_transpose_jaxpr``.

  Args:
    jaxpr: the closed jaxpr to transpose.
    in_linear: either one bool, repeated once per entry of ``jaxpr.in_avals``,
      or a sequence of bools.
    out_zeros: either one bool, repeated once per entry of
      ``jaxpr.out_avals``, or a sequence of bools.
    reduce_axes: sequence of axis names, forwarded as a tuple.

  Returns:
    Whatever ``_transpose_jaxpr`` returns; annotated as a closed jaxpr plus a
    list of bools.
  """
  # All sequence arguments are converted to tuples before the call --
  # presumably so they can serve as keys for the cached helper.
  linear_mask = ((in_linear,) * len(jaxpr.in_avals)
                 if type(in_linear) is bool else in_linear)
  zero_mask = ((out_zeros,) * len(jaxpr.out_avals)
               if type(out_zeros) is bool else out_zeros)
  return _transpose_jaxpr(
      jaxpr, tuple(linear_mask), tuple(zero_mask), tuple(reduce_axes))
@weakref_lru_cache
def _transpose_jaxpr(jaxpr, in_lin, out_zeros, reduce_axes):
in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] +
[a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero])
cell = lambda: None
@lu.wrap_init
def transposed(*args):
in_primals, out_cts = tree_unflatten(treedef, args)
in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
pe.PartialVal.known(x) for x in in_primals]
primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
out_cts)
in_cts_ = iter(in_cts)
in_cts = [next(in_cts_) if ad.is_undefined_primal(x)
else ad_util.Zero(x.aval) for x in in_primals]
assert next(in_cts_, None) is None
in_cts, cell.treedef = tree_flatten(in_cts)
return in_cts
def transposed(*args_flat):
ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)])
# Evaluate nonlinear parts using partial evaluation to get a linear jaxpr.
ins_iter = iter(ins_flat)
in_pvals = [pe.PartialVal.unknown(aval) if lin else
pe.PartialVal.known(next(ins_iter))
for aval, lin in zip(jaxpr.in_avals, in_lin)]
assert next(ins_iter, None) is None
lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
# Transpose the linear jaxpr (which only has linear inputs).
out_cts_iter = iter(out_cts_flat)
out_cts = [ad_util.Zero(aval) if zero else next(out_cts_iter)
for aval, zero in zip(jaxpr.out_avals, out_zeros)]
assert next(out_cts_iter, None) is None
dummy_args = [ad.UndefinedPrimal(v.aval) for v in lin_jaxpr.invars]
in_cts = ad.backward_pass(lin_jaxpr, reduce_axes, False, consts, dummy_args,
out_cts)
# Identify symbolic zeros in the resulting cotangents, and return nonzeros.
in_zeros = cell.in_cts_zero = [type(ct) is ad_util.Zero for ct in in_cts]
in_cts_nz, _ = partition_list(in_zeros, in_cts)
return in_cts_nz
args, treedef = tree_flatten((in_primals, out_cts))
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
return tree_unflatten(cell.treedef, in_cts) # type: ignore
ad.reducing_transposes[remat_p] = remat_transpose
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # type: ignore
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
assert not jaxpr.constvars
jaxpr_ = core.ClosedJaxpr(jaxpr, ())
jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
jaxpr_, axis_size, dims, [batching.zero_if_mapped] * len(jaxpr.outvars),
pe.close_jaxpr(jaxpr), axis_size, dims,
[batching.zero_if_mapped] * len(jaxpr.outvars),
axis_name=axis_name, main_type=main_type)
jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
out_dims = [0 if b else None for b in out_batched]
return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
batching.axis_primitive_batchers[remat_p] = remat_vmap
# TODO(mattjj,sharadmv): test this more
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], Optional[core.JaxprEqn]]:

View File

@ -70,7 +70,7 @@ from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
@ -3095,12 +3095,12 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
``pmap``, CSE can defeat the purpose of this decorator. But in some
settings, like when used inside a ``scan``, this CSE prevention mechanism
is unnecessary, in which case ``prevent_cse`` can be set to False.
policy: This is an experimental feature and the API is likely to change.
Optional callable, one of the attributes of ``jax.checkpoint_policies``,
which takes as input a type-level specification of a first-order primitive
application and returns a boolean indicating whether the corresponding
output value(s) can be saved as a residual (or, if not, instead must be
recomputed in the (co)tangent computation).
policy: Optional callable, one of the attributes of
``jax.checkpoint_policies``, which takes as input a type-level
specification of a first-order primitive application and returns a boolean
indicating whether the corresponding output value(s) can be saved as a
residual (or instead must be recomputed in the (co)tangent computation if
needed).
Returns:
A function (callable) with the same input/output behavior as ``fun`` but
@ -3155,6 +3155,9 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
... return lambda x: f1(jax.checkpoint(f2)(x))
...
"""
if config.jax_new_checkpoint and not concrete:
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy)
@wraps(fun)
@api_boundary
def remat_f(*args, **kwargs):

View File

@ -875,12 +875,12 @@ config.define_bool_state(
default=(lib.version >= (0, 3, 6)),
help=('Enables using optimization-barrier op for lowering remat.'))
# TODO(mattjj): remove after May 19 2022, NeurIPS submission deadline
# TODO(mattjj): set default to True, then remove
config.define_bool_state(
name='after_neurips',
default=True,
name='jax_new_checkpoint',
default=False,
upgrade=True,
help='Gate changes until after NeurIPS 2022 deadline.')
help='Whether to use the new jax.checkpoint implementation.')
# TODO(b/205307544): Remove flag once coordination service has rolled out.
config.define_bool_state(

View File

@ -593,8 +593,6 @@ def _ordered_unique(xs):
def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
) -> Tuple[List[bool], core.JaxprEqn]:
if not config.after_neurips:
return [True] * len(eqn.params['jaxpr'].in_avals), eqn
closed_branches = eqn.params['branches']
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

View File

@ -781,8 +781,6 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], core.JaxprEqn]:
if not config.after_neurips:
return [True] * len(eqn.params['jaxpr'].in_avals), eqn
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_xs = len(jaxpr.in_avals) - num_consts - num_carry

View File

@ -515,6 +515,7 @@ from jax.experimental import pjit
from jax.interpreters import ad, xla, batching, pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import source_info_util
@ -1675,6 +1676,16 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
out_axis_resources=(eqn.params["out_axis_resources"] +
(pjit.REPLICATED, pjit.REPLICATED)),
)))
elif eqn.primitive is ad_checkpoint.remat_p:
jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"])
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
jaxpr=_rewrite_jaxpr(jaxpr_, True, True),
)))
else:
raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")

View File

@ -16,8 +16,7 @@ import contextlib
import functools
from functools import partial
import itertools as it
from typing import Any, Callable, Dict, List, Tuple, Optional
from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
import jax
from jax.interpreters import partial_eval as pe
from jax.config import config
@ -642,8 +641,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
cotangent_in_avals, reduce_axes):
unknowns = map(is_undefined_primal, primals_in)
primal_jaxpr, tangent_jaxpr, _, _ = \
pe.partial_eval_jaxpr_nounits(_close_jaxpr(call_jaxpr), unknowns=unknowns,
instantiate=True) # type: ignore
pe.partial_eval_jaxpr_nounits(pe.close_jaxpr(call_jaxpr),
unknowns=unknowns, instantiate=True) # type: ignore
args, in_tree = tree_flatten((primals_in, cotangents_in))
transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr,
tangent_jaxpr, reduce_axes)
@ -665,10 +664,6 @@ def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
assert next(cotangents_out, None) is None
return outs
@weakref_lru_cache
def _close_jaxpr(jaxpr: core.Jaxpr) -> core.ClosedJaxpr:
return core.ClosedJaxpr(jaxpr, [])
@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
results = yield args, kwargs
@ -717,9 +712,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
return tuple(arg_cts)
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _jvp_jaxpr(jaxpr, tuple(nonzeros), inst)
def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
              instantiate: Union[bool, Sequence[bool]]
              ) -> Tuple[core.ClosedJaxpr, List[bool]]:
  """Normalize the arguments and delegate to ``_jvp_jaxpr``.

  Args:
    jaxpr: the closed jaxpr to differentiate.
    nonzeros: sequence of bools, forwarded as a tuple.
    instantiate: either one bool, repeated once per entry of
      ``jaxpr.out_avals``, or a sequence of bools.

  Returns:
    Whatever ``_jvp_jaxpr`` returns; annotated as a closed jaxpr plus a list
    of bools.
  """
  # A lone bool is expanded to one flag per output before the tuple conversion.
  inst_flags = ((instantiate,) * len(jaxpr.out_avals)
                if type(instantiate) is bool else instantiate)
  return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(inst_flags))
@weakref_lru_cache
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):

View File

@ -1509,6 +1509,10 @@ def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
return used_inputs, new_eqn
dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule
@weakref_lru_cache
def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr:
  """Wrap ``jaxpr`` in a ``ClosedJaxpr`` that carries an empty tuple of consts."""
  no_consts = ()
  return ClosedJaxpr(jaxpr, no_consts)
def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
) -> ClosedJaxpr:
"""Reorder `invars` by moving those indicated in `to_move` to the front."""

View File

@ -4784,6 +4784,16 @@ class RematTest(jtu.JaxTestCase):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
@unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
def test_fwd_caching(self):
# see above test also
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
for _ in range(20):
y, _ = jax.vjp(identity, 1.)
y.block_until_ready()
self.assertEqual(count[0], 1)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
@ -4851,7 +4861,6 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
modes=['fwd', 'rev'])
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_scan_policy(self):
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
@ -4864,7 +4873,6 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_scan_funky_custom_jvp(self):
def scan_apply(f, x):
y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1)
@ -4921,7 +4929,6 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point
self.assertEqual(jaxpr_text.count(' cos '), 2)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_scan_funky_custom_jvp2(self):
# Like the above test but instead of using jit inside custom_jvp, use scan.
@ -5081,7 +5088,6 @@ class RematTest(jtu.JaxTestCase):
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
modes=['fwd', 'rev'])
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_cond_policy(self):
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
f = new_checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x),
@ -5093,7 +5099,6 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_cond_funky_custom_jvp(self):
def cond_apply(f, x):
return lax.cond(x.sum() > -jnp.inf, f, lambda x: x, x)
@ -5149,7 +5154,6 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
def test_remat_of_cond_funky_custom_jvp2(self):
# Like the above test but instead of using jit inside custom_jvp, use cond.
@ -5395,7 +5399,6 @@ class JaxprTest(jtu.JaxTestCase):
self.assertLen(jaxpr.eqns, 0)
@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
class DCETest(jtu.JaxTestCase):
def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: List[bool],
@ -5580,9 +5583,9 @@ class DCETest(jtu.JaxTestCase):
return out, out
def f(xs):
return lax.scan(scanned_f, 1., xs)
return lax.scan(scanned_f, jnp.array(1., 'float32'), xs)
xs = jnp.arange(10.)
xs = jnp.arange(10., dtype='float32')
jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(xs).jaxpr
jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))

View File

@ -1979,6 +1979,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
for grad_func in ["grad", "value_and_grad"]
for use_remat in ["old", "new", "none"]))
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
if config.jax_new_checkpoint and use_remat == "old": raise SkipTest()
def f(x):
id_print_result = hcb.id_print(x, output_stream=testing_stream)
if use_result: