We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
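For context, here is a minimal usage sketch, assuming the new control point is exposed as the jax.remat decorator from the linked PR (illustrative only):

import jax
import jax.numpy as np

# Wrapping a function in jax.remat asks partial evaluation not to save the
# function's intermediate residuals; they are recomputed from the saved
# inputs when the backward pass needs them.
@jax.remat
def layer(w, x):
  return np.tanh(np.dot(w, x))

def loss(w, x):
  return np.sum(layer(w, x))

grad_loss = jax.grad(loss)  # differentiates through layer, recomputing the dot/tanh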
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
The original repro @levskaya showed us was essentially this OOM:
import jax.numpy as np
from jax import jit

for i in range(40):
  f = jit(lambda: 1. * np.ones((300, 1024, 1024)))
  f().block_until_ready()
Even though f was being rebound on every iteration, the cache entries
corresponding to the previous iterations of the loop were sticking around.
Instead, if the user drops all references to a function, we want to clear the
corresponding compilation cache entries (since they can never be used).
The fix here is to use a two-level cache for compiled code: the first level is
a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun,
and the second level is a regular dictionary keyed by (transforms, params,
args). Because this logic now lives in linear_util.py:cache, the
implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be
superfluous.
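As a rough illustration of that two-level structure (names here are hypothetical, not the actual linear_util.py code):

import threading
import weakref

_compiled = weakref.WeakKeyDictionary()  # level 1: keyed by the raw Python callable
_lock = threading.Lock()

def cached_call(call, raw_fun, transforms, params, args):
  # Once the user drops all references to raw_fun, the garbage collector
  # removes its level-1 entry, and every compiled executable cached under
  # it goes away with it.
  with _lock:
    inner = _compiled.setdefault(raw_fun, {})  # level 2: a plain dict
    key = (transforms, params, args)
    if key not in inner:
      inner[key] = call(raw_fun, *args)
    return inner[key]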
One unintended consequence is that this implementation no longer uses
fastcache.clru_cache for the jit and pmap compilation caches; it was easier to
implement this logic in pure Python. We might want to revisit this for
performance reasons.
This commit also incidentally fixed #1600.
Simplify the code that builds a hashable payload:
* Sort the params at construction time.
* Split the stores and the transforms into separate fields, making it easier to hash and test equality of the non-store parts.
* Don't build tuples in __eq__(); instead, compare the fields directly.
Make some small optimizations to the Store implementation: use __slots__, and test for object identity rather than relying on hasattr().
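A rough sketch of what the optimized Store looks like under those constraints (illustrative, not the exact linear_util.py code):

class _EmptyStoreValue:
  pass

_EMPTY_STORE_VALUE = _EmptyStoreValue()

class Store:
  # __slots__ avoids a per-instance __dict__; an identity check against the
  # sentinel replaces the old hasattr() test for whether a value was stored.
  __slots__ = ("_val",)

  def __init__(self):
    self._val = _EMPTY_STORE_VALUE

  def store(self, val):
    assert self._val is _EMPTY_STORE_VALUE, "Store can only be set once"
    self._val = val

  @property
  def val(self):
    assert self._val is not _EMPTY_STORE_VALUE, "Store is empty"
    return self._val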
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replacing them with a more general, thread-safe util.memoize that wraps fastcache.
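A rough sketch of what such a fastcache-backed wrapper might look like (illustrative, not the exact util.py code; the cache size is an arbitrary placeholder):

import fastcache

def memoize(fun):
  # clru_cache is implemented in C and is thread-safe, so one general-purpose
  # decorator can replace the unary- and thunk-specific helpers.
  return fastcache.clru_cache(maxsize=4096)(fun)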