In jax.remat/jax.checkpoint, don't cache on Tracers in static args

Why do we have caching in jax.remat at all? I added it in
https://github.com/google/jax/pull/11743 without much justification other than
it made some tests faster. I think I was worried that the switch to the new
remat's "initial-style" (jaxpr forming up-front) approach would regress
eager-mode performance, so I added benchmarks to measure it and then made those
fast with caching.

But the caching seems a bit too aggressive when static_argnums are involved. In
particular, I allowed caching on Tracer arguments (by object id). That seems
dangerous!

So the change here is to check whether any of the arguments marked static by
static_argnums are Tracers. If so, skip the caching. This change happens not to
affect the benchmarks at all.

PiperOrigin-RevId: 529502687
This commit is contained in:
Matthew Johnson 2023-05-04 13:41:13 -07:00 committed by jax authors
parent e6e6490ab0
commit 2845df03fc

View File

@ -348,9 +348,14 @@ class WrapHashably:
# See api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums.
# On that benchmark, including this caching makes a ~10x difference (which can
# be made arbitrary large by involving larger functions to be traced).
@weakref_lru_cache
def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
static_args: Tuple[WrapHashably, ...], nargs: int):
if any(isinstance(x.val, core.Tracer) for x in static_args):
return _dyn_args_fun_uncached(fun, static_argnums, static_args, nargs)
return _dyn_args_fun_cached(fun, static_argnums, static_args, nargs)
def _dyn_args_fun_uncached(fun: Callable, static_argnums: FrozenSet[int],
static_args: Tuple[WrapHashably, ...], nargs: int):
def new_fun(*dyn_args, **kwargs):
static_args_, dyn_args_ = iter(static_args), iter(dyn_args)
full_args = [next(static_args_).val if i in static_argnums
@ -358,6 +363,8 @@ def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
return fun(*full_args, **kwargs)
return new_fun
_dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
# This helper is similar to those in control_flow/common.py, but with
# remat-specific errors.
@weakref_lru_cache