Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)

Merge pull request #11675 from mattjj:new-remat-caching

PiperOrigin-RevId: 464543287

Commit: be8939771c
@@ -14,7 +14,7 @@

from functools import partial
import operator as op
from typing import Callable, Optional, List, Tuple
from typing import Callable, Optional, List, Tuple, Sequence, Union
import types

import jax

@@ -27,16 +27,15 @@ from jax.interpreters import mlir
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import source_info_util
from jax._src.api_util import flatten_fun
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
                           safe_zip, merge_lists)
                           safe_zip, merge_lists, weakref_lru_cache)

source_info_util.register_exclusion(__file__)

# TODO(mattjj): before this can be the standard remat implementation, we must:
# [ ] fix up callers who use the 'concrete' option (now removed)
# [ ] implement remat-of-control-flow-primitives (passing through the policy)

map = safe_map
zip = safe_zip
@@ -209,18 +208,23 @@ def checkpoint(fun: Callable, prevent_cse: bool = True,
  @api_boundary
  def fun_remat(*args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    debug = pe.debug_info(fun, in_tree, False, "checkpoint")
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
    in_avals = [shaped_abstractify(x) for x in args_flat]
    jaxpr, consts, out_tree = _trace_to_jaxpr(fun, in_tree, tuple(in_avals))
    out_flat = remat_p.bind(
        *consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr),
        prevent_cse=prevent_cse, differentiated=False, policy=policy)
    return tree_unflatten(out_tree(), out_flat)
        *consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse,
        differentiated=False, policy=policy)
    return tree_unflatten(out_tree, out_flat)
  return fun_remat

remat = checkpoint  # alias

@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
  debug = pe.debug_info(fun, in_tree, False, "checkpoint")
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  return pe.convert_constvars_jaxpr(jaxpr), consts, out_tree()
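
The `@weakref_lru_cache` on `_trace_to_jaxpr` is the core of this change: calling the
checkpointed function again with the same `fun`, tree structure, and avals reuses the
traced jaxpr instead of retracing. A minimal sketch of the idea, not the actual
`jax._src.util.weakref_lru_cache` implementation (which also evicts entries when `fun`
is garbage collected):

import functools
import weakref

def weakref_lru_cache_sketch(call, maxsize=4096):
  # Key on (weakref(fun), *args): weakref.ref hashes and compares via its
  # referent, so repeated calls with the same live function hit one entry,
  # while the cache itself does not keep user functions alive.
  cached = functools.lru_cache(maxsize=maxsize)(
      lambda weak_fun, *args: call(weak_fun(), *args))
  def wrapper(fun, *args):
    return cached(weakref.ref(fun), *args)
  return wrapper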

### Utilities

@@ -285,8 +289,7 @@ def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  assert not jaxpr.constvars
  in_nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_jvp_, out_nonzeros = ad.jvp_jaxpr(jaxpr_, in_nonzeros, False)
  jaxpr_jvp_, out_nz = ad.jvp_jaxpr(pe.close_jaxpr(jaxpr), in_nonzeros, False)
  nonzero_tangents = [t for t in tangents if type(t) is not ad_util.Zero]
  jaxpr_jvp = pe.convert_constvars_jaxpr(jaxpr_jvp_.jaxpr)
  outs = remat_p.bind(

@@ -295,7 +298,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
  out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)])
  out_tangents_ = iter(out_tangents_)
  out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_primals, out_nonzeros)]
                  for p, nz in zip(out_primals, out_nz)]
  return out_primals, out_tangents
ad.primitive_jvps[remat_p] = remat_jvp

@@ -351,45 +354,82 @@ pe.partial_eval_jaxpr_custom_rules[remat_p] = \

def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  assert not jaxpr.constvars
  in_linear = [ad.is_undefined_primal(x) for x in in_primals]
  out_zeros = [type(ct) is ad_util.Zero for ct in out_cts]
  transposed_jaxpr_, in_zeros = transpose_jaxpr(
      pe.close_jaxpr(jaxpr), in_linear, out_zeros, reduce_axes)
  transposed_jaxpr, consts = transposed_jaxpr_.jaxpr, transposed_jaxpr_.consts
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr)
  args, _ = tree_flatten((in_primals, out_cts))
  in_cts_nz = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  in_cts_nz_, in_zeros_ = iter(in_cts_nz), iter(in_zeros)
  in_cts = [None if not ad.is_undefined_primal(x) else
            ad_util.Zero(x.aval) if next(in_zeros_) else next(in_cts_nz_)
            for x in in_primals]
  assert next(in_cts_nz_, None) is next(in_zeros_, None) is None
  return in_cts
ad.reducing_transposes[remat_p] = remat_transpose

# TODO(mattjj): move this to ad.py
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: Union[bool, Sequence[bool]],
                    out_zeros: Union[bool, Sequence[bool]],
                    reduce_axes: Sequence[core.AxisName],
                    ) -> Tuple[core.ClosedJaxpr, List[bool]]:
  if type(in_linear) is bool:
    in_linear = (in_linear,) * len(jaxpr.in_avals)
  if type(out_zeros) is bool:
    out_zeros = (out_zeros,) * len(jaxpr.out_avals)
  return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros),
                          tuple(reduce_axes))

@weakref_lru_cache
def _transpose_jaxpr(jaxpr, in_lin, out_zeros, reduce_axes):
  in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin) if not lin] +
              [a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero])
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
    in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)
    in_cts_ = iter(in_cts)
    in_cts = [next(in_cts_) if ad.is_undefined_primal(x)
              else ad_util.Zero(x.aval) for x in in_primals]
    assert next(in_cts_, None) is None
    in_cts, cell.treedef = tree_flatten(in_cts)
    return in_cts
  def transposed(*args_flat):
    ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)])

    # Evaluate nonlinear parts using partial evaluation to get a linear jaxpr.
    ins_iter = iter(ins_flat)
    in_pvals = [pe.PartialVal.unknown(aval) if lin else
                pe.PartialVal.known(next(ins_iter))
                for aval, lin in zip(jaxpr.in_avals, in_lin)]
    assert next(ins_iter, None) is None
    lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
        lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)

    # Transpose the linear jaxpr (which only has linear inputs).
    out_cts_iter = iter(out_cts_flat)
    out_cts = [ad_util.Zero(aval) if zero else next(out_cts_iter)
               for aval, zero in zip(jaxpr.out_avals, out_zeros)]
    assert next(out_cts_iter, None) is None
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in lin_jaxpr.invars]
    in_cts = ad.backward_pass(lin_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)

    # Identify symbolic zeros in the resulting cotangents, and return nonzeros.
    in_zeros = cell.in_cts_zero = [type(ct) is ad_util.Zero for ct in in_cts]
    in_cts_nz, _ = partition_list(in_zeros, in_cts)
    return in_cts_nz

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
ad.reducing_transposes[remat_p] = remat_transpose
  transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
  return transposed_jaxpr, cell.in_cts_zero  # type: ignore
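
A detail worth noting in `_transpose_jaxpr`: tracing only returns traced outputs, so the
static byproduct `in_cts_zero` escapes through the `cell` object that `transposed` closes
over. A hedged mini-example of the same pattern, with illustrative names not taken from
the diff:

def trace(traced_fun, xs):
  # Stand-in for pe.trace_to_jaxpr_dynamic: it only sees traced inputs/outputs.
  return traced_fun(xs)

def double_with_aux(xs):
  cell = lambda: None                       # throwaway attribute bag
  def traced(xs):
    cell.n_zeros = sum(x == 0 for x in xs)  # static metadata, recorded at trace time
    return [2 * x for x in xs]
  ys = trace(traced, xs)
  return ys, cell.n_zeros                   # aux escapes via the closure, not the trace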

def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
  assert not jaxpr.constvars
  jaxpr_ = core.ClosedJaxpr(jaxpr, ())
  jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
      jaxpr_, axis_size, dims, [batching.zero_if_mapped] * len(jaxpr.outvars),
      pe.close_jaxpr(jaxpr), axis_size, dims,
      [batching.zero_if_mapped] * len(jaxpr.outvars),
      axis_name=axis_name, main_type=main_type)
  jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
  out_dims = [0 if b else None for b in out_batched]
  return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
batching.axis_primitive_batchers[remat_p] = remat_vmap
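
For context, this batching rule is what makes vmap-of-checkpoint work by batching the
stored jaxpr directly; a hedged usage sketch, not part of the diff:

import jax
import jax.numpy as jnp

f = jax.checkpoint(jnp.sin)         # remat_p carries the jaxpr of sin
ys = jax.vmap(f)(jnp.arange(4.0))   # dispatched through remat_vmap, no Python loop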

# TODO(mattjj,sharadmv): test this more
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
@@ -70,7 +70,7 @@ from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
                                     local_devices, process_index,
                                     process_count, host_id, host_ids,
                                     host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,

@@ -3095,12 +3095,12 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
      ``pmap``, CSE can defeat the purpose of this decorator. But in some
      settings, like when used inside a ``scan``, this CSE prevention mechanism
      is unnecessary, in which case ``prevent_cse`` can be set to False.
    policy: This is an experimental feature and the API is likely to change.
      Optional callable, one of the attributes of ``jax.checkpoint_policies``,
      which takes as input a type-level specification of a first-order primitive
      application and returns a boolean indicating whether the corresponding
      output value(s) can be saved as a residual (or, if not, instead must be
      recomputed in the (co)tangent computation).
    policy: Optional callable, one of the attributes of
      ``jax.checkpoint_policies``, which takes as input a type-level
      specification of a first-order primitive application and returns a boolean
      indicating whether the corresponding output value(s) can be saved as a
      residual (or instead must be recomputed in the (co)tangent computation if
      needed).

  Returns:
    A function (callable) with the same input/output behavior as ``fun`` but

@@ -3155,6 +3155,9 @@ def checkpoint(fun: Callable, concrete: bool = False, prevent_cse: bool = True,
  ... return lambda x: f1(jax.checkpoint(f2)(x))
  ...
  """
  if config.jax_new_checkpoint and not concrete:
    return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy)

  @wraps(fun)
  @api_boundary
  def remat_f(*args, **kwargs):
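
A short usage sketch of the documented `policy` argument; `checkpoint_dots` is one of the
attributes of `jax.checkpoint_policies`, and the shapes here are arbitrary:

import jax
import jax.numpy as jnp

f = jax.checkpoint(lambda W, x: jnp.sin(W @ x),
                   policy=jax.checkpoint_policies.checkpoint_dots)
# Dot-product results are saved as residuals; sin's input is recomputed
# during the backward pass.
grads = jax.grad(lambda W, x: f(W, x).sum())(jnp.ones((3, 3)), jnp.ones(3))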

@@ -875,12 +875,12 @@ config.define_bool_state(
    default=(lib.version >= (0, 3, 6)),
    help=('Enables using optimization-barrier op for lowering remat.'))

# TODO(mattjj): remove after May 19 2022, NeurIPS submission deadline
# TODO(mattjj): set default to True, then remove
config.define_bool_state(
    name='after_neurips',
    default=True,
    name='jax_new_checkpoint',
    default=False,
    upgrade=True,
    help='Gate changes until after NeurIPS 2022 deadline.')
    help='Whether to use the new jax.checkpoint implementation.')

# TODO(b/205307544): Remove flag once coordination service has rolled out.
config.define_bool_state(
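
Until the default flips, the new implementation can be opted into explicitly; a hedged
sketch of the usual ways a `define_bool_state` flag is set:

import jax

jax.config.update('jax_new_checkpoint', True)   # programmatic opt-in
# Alternatively: set JAX_NEW_CHECKPOINT=1 in the environment, or pass
# --jax_new_checkpoint=true where absl flags are parsed.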

@@ -593,8 +593,6 @@ def _ordered_unique(xs):

def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
                   ) -> Tuple[List[bool], core.JaxprEqn]:
  if not config.after_neurips:
    return [True] * len(eqn.params['jaxpr'].in_avals), eqn
  closed_branches = eqn.params['branches']
  branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

@@ -781,8 +781,6 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):

def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
                   ) -> Tuple[List[bool], core.JaxprEqn]:
  if not config.after_neurips:
    return [True] * len(eqn.params['jaxpr'].in_avals), eqn
  jaxpr = eqn.params['jaxpr']
  num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
  num_xs = len(jaxpr.in_avals) - num_consts - num_carry
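
These rules feed `pe.dce_jaxpr`, which prunes equations whose outputs are unused; a
hedged sketch of it in action, mirroring the call in the test further down:

import jax
import jax.numpy as jnp
from jax.interpreters import partial_eval as pe

jaxpr = jax.make_jaxpr(lambda x: (jnp.sin(x), jnp.cos(x)))(1.0).jaxpr
# Keep only the first output; the cos equation should be eliminated.
pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, False])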

@@ -515,6 +515,7 @@ from jax.experimental import pjit
from jax.interpreters import ad, xla, batching, pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import source_info_util

@@ -1675,6 +1676,16 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
            out_axis_resources=(eqn.params["out_axis_resources"] +
                                (pjit.REPLICATED, pjit.REPLICATED)),
        )))
  elif eqn.primitive is ad_checkpoint.remat_p:
    jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"])
    eqns.append(
        eqn.replace(
            invars=eqn.invars + [input_token_var, input_itoken_var],
            outvars=eqn.outvars + [output_token_var, output_itoken_var],
            params=dict(
                eqn.params,
                jaxpr=_rewrite_jaxpr(jaxpr_, True, True),
            )))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
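
The new `remat_p` branch threads the outfeed tokens through checkpointed regions, so taps
inside a rematerialized function keep working; a hedged usage sketch (and under remat the
tap may fire again when the forward pass is recomputed):

import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def f(x):
  x = hcb.id_print(x, what="x")   # tap inside the checkpointed region
  return jnp.sin(x)

y = jax.grad(jax.checkpoint(f))(1.0)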

@@ -16,8 +16,7 @@ import contextlib
import functools
from functools import partial
import itertools as it
from typing import Any, Callable, Dict, List, Tuple, Optional

from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
import jax
from jax.interpreters import partial_eval as pe
from jax.config import config

@@ -642,8 +641,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
                    cotangent_in_avals, reduce_axes):
  unknowns = map(is_undefined_primal, primals_in)
  primal_jaxpr, tangent_jaxpr, _, _ = \
      pe.partial_eval_jaxpr_nounits(_close_jaxpr(call_jaxpr), unknowns=unknowns,
                                    instantiate=True)  # type: ignore
      pe.partial_eval_jaxpr_nounits(pe.close_jaxpr(call_jaxpr),
                                    unknowns=unknowns, instantiate=True)  # type: ignore
  args, in_tree = tree_flatten((primals_in, cotangents_in))
  transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr,
                                  tangent_jaxpr, reduce_axes)

@@ -665,10 +664,6 @@ def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
  assert next(cotangents_out, None) is None
  return outs

@weakref_lru_cache
def _close_jaxpr(jaxpr: core.Jaxpr) -> core.ClosedJaxpr:
  return core.ClosedJaxpr(jaxpr, [])

@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
  results = yield args, kwargs

@@ -717,9 +712,12 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  return tuple(arg_cts)


def jvp_jaxpr(jaxpr, nonzeros, instantiate):
  inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
  return _jvp_jaxpr(jaxpr, tuple(nonzeros), inst)
def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
              instantiate: Union[bool, Sequence[bool]]
              ) -> Tuple[core.ClosedJaxpr, List[bool]]:
  if type(instantiate) is bool:
    instantiate = (instantiate,) * len(jaxpr.out_avals)
  return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))

@weakref_lru_cache
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):

@@ -1509,6 +1509,10 @@ def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
  return used_inputs, new_eqn
dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule

@weakref_lru_cache
def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr:
  return ClosedJaxpr(jaxpr, ())
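
Memoizing `close_jaxpr` matters because downstream caches such as `_jvp_jaxpr` key on
their `ClosedJaxpr` argument, so returning the same object for the same `Jaxpr` turns
repeated closings into cache hits. A hedged illustration:

import jax
from jax.interpreters import partial_eval as pe

jaxpr = jax.make_jaxpr(lambda x: x + 1)(1.0).jaxpr
# Same Jaxpr in, same ClosedJaxpr object out (while the entry stays cached).
assert pe.close_jaxpr(jaxpr) is pe.close_jaxpr(jaxpr)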

def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
                          ) -> ClosedJaxpr:
  """Reorder `invars` by moving those indicated in `to_move` to the front."""

@@ -4784,6 +4784,16 @@ class RematTest(jtu.JaxTestCase):
    f_vjp(1.)[0].block_until_ready()
    self.assertEqual(count[0], 1)  # fwd execute_trivial, backward_pass on bwd

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
  def test_fwd_caching(self):
    # see above test also
    identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
    with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
      for _ in range(20):
        y, _ = jax.vjp(identity, 1.)
        y.block_until_ready()
    self.assertEqual(count[0], 1)

  @parameterized.named_parameters(
      {"testcase_name": f"{suffix}", "remat": remat}
      for suffix, remat in [

@@ -4851,7 +4861,6 @@ class RematTest(jtu.JaxTestCase):
    jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
                    modes=['fwd', 'rev'])

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_scan_policy(self):
    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
    to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))

@@ -4864,7 +4873,6 @@ class RematTest(jtu.JaxTestCase):
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 0)

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_scan_funky_custom_jvp(self):
    def scan_apply(f, x):
      y, _ = lax.scan(lambda x, _: (f(x), None), x, None, length=1)

@@ -4921,7 +4929,6 @@ class RematTest(jtu.JaxTestCase):
    self.assertEqual(jaxpr_text.count(' sin '), 2)  # +1 b/c dce fixed point
    self.assertEqual(jaxpr_text.count(' cos '), 2)

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_scan_funky_custom_jvp2(self):
    # Like the above test but instead of using jit inside custom_jvp, use scan.

@@ -5081,7 +5088,6 @@ class RematTest(jtu.JaxTestCase):
    jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
                    modes=['fwd', 'rev'])

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_cond_policy(self):
    save_cos = lambda prim, *_, **__: str(prim) == 'cos'
    f = new_checkpoint(lambda x: lax.cond(x > 0, jnp.sin, lambda x: x, x),

@@ -5093,7 +5099,6 @@ class RematTest(jtu.JaxTestCase):
    self.assertEqual(jaxpr_text.count(' sin '), 0)
    self.assertEqual(jaxpr_text.count(' cos '), 0)

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_cond_funky_custom_jvp(self):
    def cond_apply(f, x):
      return lax.cond(x.sum() > -jnp.inf, f, lambda x: x, x)

@@ -5149,7 +5154,6 @@ class RematTest(jtu.JaxTestCase):
    self.assertEqual(jaxpr_text.count(' sin '), 1)
    self.assertEqual(jaxpr_text.count(' cos '), 2)

  @unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
  def test_remat_of_cond_funky_custom_jvp2(self):
    # Like the above test but instead of using jit inside custom_jvp, use cond.

@@ -5395,7 +5399,6 @@ class JaxprTest(jtu.JaxTestCase):
    self.assertLen(jaxpr.eqns, 0)


@unittest.skipIf(not config.after_neurips, "skip until neurips deadline")
class DCETest(jtu.JaxTestCase):

  def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: List[bool],

@@ -5580,9 +5583,9 @@ class DCETest(jtu.JaxTestCase):
      return out, out

    def f(xs):
      return lax.scan(scanned_f, 1., xs)
      return lax.scan(scanned_f, jnp.array(1., 'float32'), xs)

    xs = jnp.arange(10.)
    xs = jnp.arange(10., dtype='float32')
    jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(xs).jaxpr

    jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))

@@ -1979,6 +1979,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
          for grad_func in ["grad", "value_and_grad"]
          for use_remat in ["old", "new", "none"]))
  def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
    if config.jax_new_checkpoint and use_remat == "old": raise SkipTest()

    def f(x):
      id_print_result = hcb.id_print(x, output_stream=testing_stream)
      if use_result: