mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
In jax.remat/jax.checkpoint, don't cache on Tracers in static args
Why do we have caching in jax.remat at all? I added it in https://github.com/google/jax/pull/11743 without much justification other than it made some tests faster. I think I was worried that the switch to the new remat's "initial-style" (jaxpr forming up-front) approach would regress eager-mode performance, so I added benchmarks to measure it and then made those fast with caching. But the caching seems a bit too aggressive when static_argnums are involved. In particular, I allowed caching on Tracer arguments (by object id). That seems dangerous! So the change here is to check whether any of the arguments marked static by static_argnums are Tracers. If so, skip the caching. This change happens not to affect the benchmarks at all. PiperOrigin-RevId: 529502687
This commit is contained in:
parent
e6e6490ab0
commit
2845df03fc
@ -348,9 +348,14 @@ class WrapHashably:
|
||||
# See api_benchmark.py:bench_remat_eager_retracing_overheads_static_argnums.
|
||||
# On that benchmark, including this caching makes a ~10x difference (which can
|
||||
# be made arbitrary large by involving larger functions to be traced).
|
||||
@weakref_lru_cache
|
||||
def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
|
||||
static_args: Tuple[WrapHashably, ...], nargs: int):
|
||||
if any(isinstance(x.val, core.Tracer) for x in static_args):
|
||||
return _dyn_args_fun_uncached(fun, static_argnums, static_args, nargs)
|
||||
return _dyn_args_fun_cached(fun, static_argnums, static_args, nargs)
|
||||
|
||||
def _dyn_args_fun_uncached(fun: Callable, static_argnums: FrozenSet[int],
|
||||
static_args: Tuple[WrapHashably, ...], nargs: int):
|
||||
def new_fun(*dyn_args, **kwargs):
|
||||
static_args_, dyn_args_ = iter(static_args), iter(dyn_args)
|
||||
full_args = [next(static_args_).val if i in static_argnums
|
||||
@ -358,6 +363,8 @@ def _dyn_args_fun(fun: Callable, static_argnums: FrozenSet[int],
|
||||
return fun(*full_args, **kwargs)
|
||||
return new_fun
|
||||
|
||||
_dyn_args_fun_cached = weakref_lru_cache(_dyn_args_fun_uncached)
|
||||
|
||||
# This helper is similar to those in control_flow/common.py, but with
|
||||
# remat-specific errors.
|
||||
@weakref_lru_cache
|
||||
|
Loading…
x
Reference in New Issue
Block a user