From cbcfe95e800e3bcc6165484b0f46c78764cd2296 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 29 Jul 2022 15:23:29 -0700 Subject: [PATCH] fix ad_checkpoint.checkpoint caching issue Also add a config option to switch to the new checkpoint implementation globally (default False for now), as the first step in replacing and then deleting old remat. --- jax/_src/ad_checkpoint.py | 116 +++++++++++++++------- jax/_src/api.py | 17 ++-- jax/_src/config.py | 8 +- jax/_src/lax/control_flow/conditionals.py | 2 - jax/_src/lax/control_flow/loops.py | 2 - jax/experimental/host_callback.py | 11 ++ jax/interpreters/ad.py | 20 ++-- jax/interpreters/partial_eval.py | 4 + tests/api_test.py | 21 ++-- tests/host_callback_test.py | 2 + 10 files changed, 130 insertions(+), 73 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 0fbdf8d72..28a38694c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -14,7 +14,7 @@ from functools import partial import operator as op -from typing import Callable, Optional, List, Tuple +from typing import Callable, Optional, List, Tuple, Sequence, Union import types import jax @@ -27,16 +27,15 @@ from jax.interpreters import mlir from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import source_info_util -from jax._src.api_util import flatten_fun +from jax._src.api_util import flatten_fun, shaped_abstractify from jax._src.traceback_util import api_boundary from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, - safe_zip, merge_lists) + safe_zip, merge_lists, weakref_lru_cache) source_info_util.register_exclusion(__file__) # TODO(mattjj): before this can be the standard remat implementation, we must: # [ ] fix up callers who use the 'concrete' option (now removed) -# [ ] implement remat-of-control-flow-primitives (passing through the policy) map = safe_map zip = safe_zip @@ -209,18 +208,23 @@ def checkpoint(fun: Callable, prevent_cse: bool = True, @api_boundary def fun_remat(*args, **kwargs): args_flat, in_tree = tree_flatten((args, kwargs)) - flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] - debug = pe.debug_info(fun, in_tree, False, "checkpoint") - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + in_avals = [shaped_abstractify(x) for x in args_flat] + jaxpr, consts, out_tree = _trace_to_jaxpr(fun, in_tree, tuple(in_avals)) out_flat = remat_p.bind( - *consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr), - prevent_cse=prevent_cse, differentiated=False, policy=policy) - return tree_unflatten(out_tree(), out_flat) + *consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse, + differentiated=False, policy=policy) + return tree_unflatten(out_tree, out_flat) return fun_remat remat = checkpoint # alias +@weakref_lru_cache +def _trace_to_jaxpr(fun, in_tree, in_avals): + debug = pe.debug_info(fun, in_tree, False, "checkpoint") + flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + return pe.convert_constvars_jaxpr(jaxpr), consts, out_tree() + ### Utilities @@ -285,8 +289,7 @@ def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy): def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): assert not jaxpr.constvars in_nonzeros = [type(t) is not ad_util.Zero for t in tangents] - jaxpr_ = core.ClosedJaxpr(jaxpr, ()) - jaxpr_jvp_, out_nonzeros = ad.jvp_jaxpr(jaxpr_, in_nonzeros, False) + jaxpr_jvp_, out_nz = ad.jvp_jaxpr(pe.close_jaxpr(jaxpr), in_nonzeros, False) nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero] jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr) outs = remat_p.bind( @@ -295,7 +298,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)]) out_tangents_ = iter(out_tangents_) out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p) - for p, nz in zip(out_primals, out_nonzeros)] + for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents ad.primitive_jvps[remat_p] = remat_jvp @@ -351,45 +354,82 @@ pe.partial_eval_jaxpr_custom_rules[remat_p] = \ def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params): assert not jaxpr.constvars + in_linear = [ad.is_undefined_primal(x) for x in in_primals] + out_zeros = [type(ct) is ad_util.Zero for ct in out_cts] + transposed_jaxpr_, in_zeros = transpose_jaxpr( + pe.close_jaxpr(jaxpr), in_linear, out_zeros, reduce_axes) + transposed_jaxpr, consts = transposed_jaxpr_.jaxpr, transposed_jaxpr_.consts + transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr) + args, _ = tree_flatten((in_primals, out_cts)) + in_cts_nz = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params) + in_cts_nz_, in_zeros_ = iter(in_cts_nz), iter(in_zeros) + in_cts = [None if not ad.is_undefined_primal(x) else + ad_util.Zero(x.aval) if next(in_zeros_) else next(in_cts_nz_) + for x in in_primals] + assert next(in_cts_nz_, None) is next(in_zeros_, None) is None + return in_cts +ad.reducing_transposes[remat_p] = remat_transpose + +# TODO(mattjj): move this to ad.py +def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: Union[bool, Sequence[bool]], + out_zeros: Union[bool, Sequence[bool]], + reduce_axes: Sequence[core.AxisName], + ) -> Tuple[core.ClosedJaxpr, List[bool]]: + if type(in_linear) is bool: + in_linear = (in_linear,) * len(jaxpr.in_avals) + if type(out_zeros) is bool: + out_zeros = (out_zeros,) * len(jaxpr.out_avals) + return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros), + tuple(reduce_axes)) + +@weakref_lru_cache +def _transpose_jaxpr(jaxpr, in_lin, out_zeros, reduce_axes): + in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] + + [a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero]) cell = lambda: None @lu.wrap_init - def transposed(*args): - in_primals, out_cts = tree_unflatten(treedef, args) - in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else - pe.PartialVal.known(x) for x in in_primals] - primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ())) - t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False) - dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars] - in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args, - out_cts) - in_cts_ = iter(in_cts) - in_cts = [next(in_cts_) if ad.is_undefined_primal(x) - else ad_util.Zero(x.aval) for x in in_primals] - assert next(in_cts_, None) is None - in_cts, cell.treedef = tree_flatten(in_cts) - return in_cts + def transposed(*args_flat): + ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)]) + + # Evaluate nonlinear parts using partial evaluation to get a linear jaxpr. + ins_iter = iter(ins_flat) + in_pvals = [pe.PartialVal.unknown(aval) if lin else + pe.PartialVal.known(next(ins_iter)) + for aval, lin in zip(jaxpr.in_avals, in_lin)] + assert next(ins_iter, None) is None + lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits( + lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False) + + # Transpose the linear jaxpr (which only has linear inputs). + out_cts_iter = iter(out_cts_flat) + out_cts = [ad_util.Zero(aval) if zero else next(out_cts_iter) + for aval, zero in zip(jaxpr.out_avals, out_zeros)] + assert next(out_cts_iter, None) is None + dummy_args = [ad.UndefinedPrimal(v.aval) for v in lin_jaxpr.invars] + in_cts = ad.backward_pass(lin_jaxpr, reduce_axes, False, consts, dummy_args, + out_cts) + + # Identify symbolic zeros in the resulting cotangents, and return nonzeros. + in_zeros = cell.in_cts_zero = [type(ct) is ad_util.Zero for ct in in_cts] + in_cts_nz, _ = partition_list(in_zeros, in_cts) + return in_cts_nz - args, treedef = tree_flatten((in_primals, out_cts)) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args] transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals) - transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_) - in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params) - return tree_unflatten(cell.treedef, in_cts) # type: ignore -ad.reducing_transposes[remat_p] = remat_transpose + transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) + return transposed_jaxpr, cell.in_cts_zero # type: ignore def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params): assert not jaxpr.constvars - jaxpr_ = core.ClosedJaxpr(jaxpr, ()) jaxpr_batched_, out_batched = batching.batch_jaxpr_axes( - jaxpr_, axis_size, dims, [batching.zero_if_mapped] * len(jaxpr.outvars), + pe.close_jaxpr(jaxpr), axis_size, dims, + [batching.zero_if_mapped] * len(jaxpr.outvars), axis_name=axis_name, main_type=main_type) jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts out_dims = [0 if b else None for b in out_batched] return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims batching.axis_primitive_batchers[remat_p] = remat_vmap -# TODO(mattjj,sharadmv): test this more # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn ) -> Tuple[List[bool], Optional[core.JaxprEqn]]: diff --git a/jax/_src/api.py b/jax/_src/api.py index 81459c4ab..1207f98fc 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -70,7 +70,7 @@ from jax._src.lib.xla_bridge import (device_count, local_device_count, devices, local_devices, process_index, process_count, host_id, host_ids, host_count, default_backend) -from jax.ad_checkpoint import checkpoint_policies +from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint from jax.core import ShapedArray, raise_to_shaped from jax.custom_batching import custom_vmap from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp, @@ -3095,12 +3095,12 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True, ``pmap``, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a ``scan``, this CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` can be set to False. - policy: This is an experimental feature and the API is likely to change. - Optional callable, one of the attributes of ``jax.checkpoint_policies``, - which takes as input a type-level specification of a first-order primitive - application and returns a boolean indicating whether the corresponding - output value(s) can be saved as a residual (or, if not, instead must be - recomputed in the (co)tangent computation). + policy: Optional callable, one of the attributes of + ``jax.checkpoint_policies``, which takes as input a type-level + specification of a first-order primitive application and returns a boolean + indicating whether the corresponding output value(s) can be saved as a + residual (or instead must be recomputed in the (co)tangent computation if + needed). Returns: A function (callable) with the same input/output behavior as ``fun`` but @@ -3155,6 +3155,9 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True, ... return lambda x: f1(jax.checkpoint(f2)(x)) ... """ + if config.jax_new_checkpoint and not concrete: + return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy) + @wraps(fun) @api_boundary def remat_f(*args, **kwargs): diff --git a/jax/_src/config.py b/jax/_src/config.py index e5e9f30e3..a88ff6f2b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -875,12 +875,12 @@ config.define_bool_state( default=(lib.version >= (0, 3, 6)), help=('Enables using optimization-barrier op for lowering remat.')) -# TODO(mattjj): remove after May 19 2022, NeurIPS submission deadline +# TODO(mattjj): set default to True, then remove config.define_bool_state( - name='after_neurips', - default=True, + name='jax_new_checkpoint', + default=False, upgrade=True, - help='Gate changes until after NeurIPS 2022 deadline.') + help='Whether to use the new jax.checkpoint implementation.') # TODO(b/205307544): Remove flag once coordination service has rolled out. config.define_bool_state( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 5a6314fd2..e52f97f95 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -593,8 +593,6 @@ def _ordered_unique(xs): def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn, ) -> Tuple[List[bool], core.JaxprEqn]: - if not config.after_neurips: - return [True] * len(eqn.params['jaxpr'].in_avals), eqn closed_branches = eqn.params['branches'] branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index be26089b0..91ff1188d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -781,8 +781,6 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn ) -> Tuple[List[bool], core.JaxprEqn]: - if not config.after_neurips: - return [True] * len(eqn.params['jaxpr'].in_avals), eqn jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_xs = len(jaxpr.in_avals) - num_consts - num_carry diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 098e2dde7..c9c870c1f 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -515,6 +515,7 @@ from jax.experimental import pjit from jax.interpreters import ad, xla, batching, pxla from jax.interpreters import partial_eval as pe from jax.interpreters import mlir +from jax._src import ad_checkpoint from jax._src import dispatch from jax._src import pretty_printer as pp from jax._src import source_info_util @@ -1675,6 +1676,16 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], out_axis_resources=(eqn.params["out_axis_resources"] + (pjit.REPLICATED, pjit.REPLICATED)), ))) + elif eqn.primitive is ad_checkpoint.remat_p: + jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"]) + eqns.append( + eqn.replace( + invars=eqn.invars + [input_token_var, input_itoken_var], + outvars=eqn.outvars + [output_token_var, output_itoken_var], + params=dict( + eqn.params, + jaxpr=_rewrite_jaxpr(jaxpr_, True, True), + ))) else: raise NotImplementedError(f"outfeed rewrite {eqn.primitive}") diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index f166f1b0a..ca7919f8c 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -16,8 +16,7 @@ import contextlib import functools from functools import partial import itertools as it -from typing import Any, Callable, Dict, List, Tuple, Optional - +from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union import jax from jax.interpreters import partial_eval as pe from jax.config import config @@ -642,8 +641,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals, reduce_axes): unknowns = map(is_undefined_primal, primals_in) primal_jaxpr, tangent_jaxpr, _, _ = \ - pe.partial_eval_jaxpr_nounits(_close_jaxpr(call_jaxpr), unknowns=unknowns, - instantiate=True) # type: ignore + pe.partial_eval_jaxpr_nounits(pe.close_jaxpr(call_jaxpr), + unknowns=unknowns, instantiate=True) # type: ignore args, in_tree = tree_flatten((primals_in, cotangents_in)) transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr, tangent_jaxpr, reduce_axes) @@ -665,10 +664,6 @@ def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes, assert next(cotangents_out, None) is None return outs -@weakref_lru_cache -def _close_jaxpr(jaxpr: core.Jaxpr) -> core.ClosedJaxpr: - return core.ClosedJaxpr(jaxpr, []) - @lu.transformation_with_aux def nonzero_outputs(*args, **kwargs): results = yield args, kwargs @@ -717,9 +712,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): return tuple(arg_cts) -def jvp_jaxpr(jaxpr, nonzeros, instantiate): - inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _jvp_jaxpr(jaxpr, tuple(nonzeros), inst) +def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], + instantiate: Union[bool, Sequence[bool]] + ) -> Tuple[core.ClosedJaxpr, List[bool]]: + if type(instantiate) is bool: + instantiate = (instantiate,) * len(jaxpr.out_avals) + return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate)) @weakref_lru_cache def _jvp_jaxpr(jaxpr, nonzeros, instantiate): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 261c41be0..3a6250305 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1509,6 +1509,10 @@ def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn return used_inputs, new_eqn dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule +@weakref_lru_cache +def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr: + return ClosedJaxpr(jaxpr, ()) + def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] ) -> ClosedJaxpr: """Reorder `invars` by moving those indicated in `to_move` to the front.""" diff --git a/tests/api_test.py b/tests/api_test.py index 7552ab630..106208b33 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4784,6 +4784,16 @@ class RematTest(jtu.JaxTestCase): f_vjp(1.)[0].block_until_ready() self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd + @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here") + def test_fwd_caching(self): + # see above test also + identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) + with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + for _ in range(20): + y, _ = jax.vjp(identity, 1.) + y.block_until_ready() + self.assertEqual(count[0], 1) + @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ @@ -4851,7 +4861,6 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, modes=['fwd', 'rev']) - @unittest.skipIf(not config.after_neurips, "skip until neurips deadline") def test_remat_of_scan_policy(self): save_cos = lambda prim, *_, **__: str(prim) == 'cos' to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c)) @@ -4864,7 +4873,6 @@ class RematTest(jtu.JaxTestCase): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 0) - @unittest.skipIf(not config.after_neurips, "skip until neurips deadline") def test_remat_of_scan_funky_custom_jvp(self): def scan_apply(f, x): y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1) @@ -4921,7 +4929,6 @@ class RematTest(jtu.JaxTestCase): self.assertEqual(jaxpr_text.count(' sin '), 2) # +1 b/c dce fixed point self.assertEqual(jaxpr_text.count(' cos '), 2) - @unittest.skipIf(not config.after_neurips, "skip until neurips deadline") def test_remat_of_scan_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use scan. @@ -5081,7 +5088,6 @@ class RematTest(jtu.JaxTestCase): jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, modes=['fwd', 'rev']) - @unittest.skipIf(not config.after_neurips, "skip until neurips deadline") def test_remat_of_cond_policy(self): save_cos = lambda prim, *_, **__: str(prim) == 'cos' f = new_checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x), @@ -5093,7 +5099,6 @@ class RematTest(jtu.JaxTestCase): self.assertEqual(jaxpr_text.count(' sin '), 0) self.assertEqual(jaxpr_text.count(' cos '), 0) - @unittest.skipIf(not config.after_neurips, "skip until neurips deadline") def test_remat_of_cond_funky_custom_jvp(self): def cond_apply(f, x): return lax.cond(x.sum() > -jnp.inf, f, lambda x: x, x) @@ -5149,7 +5154,6 @@ class RematTest(jtu.JaxTestCase): self.assertEqual(jaxpr_text.count(' sin '), 1) self.assertEqual(jaxpr_text.count(' cos '), 2) - @unittest.skipIf(not config.after_neurips, "skip until neurips deadline") def test_remat_of_cond_funky_custom_jvp2(self): # Like the above test but instead of using jit inside custom_jvp, use cond. @@ -5395,7 +5399,6 @@ class JaxprTest(jtu.JaxTestCase): self.assertLen(jaxpr.eqns, 0) -@unittest.skipIf(not config.after_neurips, "skip until neurips deadline") class DCETest(jtu.JaxTestCase): def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: List[bool], @@ -5580,9 +5583,9 @@ class DCETest(jtu.JaxTestCase): return out, out def f(xs): - return lax.scan(scanned_f, 1., xs) + return lax.scan(scanned_f, jnp.array(1., 'float32'), xs) - xs = jnp.arange(10.) + xs = jnp.arange(10., dtype='float32') jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(xs).jaxpr jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars)) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 166c0a893..5c3cdd434 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1979,6 +1979,8 @@ class HostCallbackTapTest(jtu.JaxTestCase): for grad_func in ["grad", "value_and_grad"] for use_remat in ["old", "new", "none"])) def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): + if config.jax_new_checkpoint and use_remat == "old": raise SkipTest() + def f(x): id_print_result = hcb.id_print(x, output_stream=testing_stream) if use_result: