18 Commits

Author SHA1 Message Date
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
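Note: the trade-off described in the message above (store intermediate residuals versus recompute them from saved inputs) can be sketched with a hand-written toy. This is not how JAX implements it: the message says JAX composes forward-mode, partial evaluation, and transposition. The function names here are hypothetical and the derivative of sin(sin(x)) is written out manually for illustration.

```python
import math


def sin_chain_saved(x):
    """Forward pass that stores intermediate residuals for the backward pass."""
    y1 = math.sin(x)
    y2 = math.sin(y1)
    # Residuals computed during the forward pass and kept alive in the closure.
    residuals = (math.cos(x), math.cos(y1))

    def backward(ct):
        dy1_dx, dy2_dy1 = residuals
        return ct * dy2_dy1 * dy1_dx

    return y2, backward


def sin_chain_remat(x):
    """Forward pass that saves only the input; the backward pass recomputes."""
    y2 = math.sin(math.sin(x))

    def backward(ct):
        # Recompute the intermediate from the saved input instead of storing it.
        y1 = math.sin(x)
        return ct * math.cos(y1) * math.cos(x)

    return y2, backward


out_saved, bwd_saved = sin_chain_saved(0.7)
out_remat, bwd_remat = sin_chain_remat(0.7)
```

Both variants give the same output and gradient; the second trades extra computation in the backward pass for holding less state between the passes.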
Matthew Johnson
2867e4be08 fix grad of jit caching bug
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-26 09:04:35 -08:00
Matthew Johnson
d2156ea1c5 improve names, avoid double lookup (thanks @hawkinsp) 2019-10-31 16:26:29 -07:00
Matthew Johnson
8bcee8d45f fix a leak where compiled results lived too long
The original repro @levskaya showed us was essentially this OOM:

  from jax import jit
  import jax.numpy as np  # np refers to jax.numpy in this repro

  for i in range(40):
    f = jit(lambda: 1. * np.ones((300, 1024, 1024)))
    f().block_until_ready()

Even though f was being rebound on every iteration, the cache entries
corresponding to the previous iterations of the loop were sticking around.

Instead, if the user drops all references to a function, we want to clear the
corresponding compilation cache entries (since they can never be used).

The fix here is to use a two-level cache for compiled code: the first level is
a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun,
and the second level is a regular dictionary keyed by (transforms, params,
args). Because this logic is now present in linear_util.py:cache, the
implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be superfluous
now.

One unintended consequence is that this implementation no longer uses
fastcache.clru_cache for the jit and pmap compilation caches. It was easier to
implement this logic in pure Python. We might want to revisit this for
performance reasons.

This commit also incidentally fixed #1600.
2019-10-31 16:26:29 -07:00
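Note: a minimal pure-Python sketch of the two-level cache idea described in the message above, assuming the first level is keyed weakly by the callable and the second level by the remaining hashable arguments. The names weak_two_level_cache and compile_stub are hypothetical and are not taken from linear_util.py. One caveat that the sketch depends on: a cached value must not hold a strong reference to the callable, otherwise the weak entry can never be dropped.

```python
import gc
import weakref


def weak_two_level_cache(call):
    """Memoize call(fun, *key_parts); entries vanish once `fun` is collected."""
    # First level: weakly keyed by the raw Python callable.
    # Second level: a regular dict keyed by the remaining hashable arguments.
    caches = weakref.WeakKeyDictionary()

    def memoized(fun, *key_parts):
        per_fun = caches.setdefault(fun, {})
        if key_parts not in per_fun:
            per_fun[key_parts] = call(fun, *key_parts)
        return per_fun[key_parts]

    memoized.cache_size = lambda: len(caches)
    return memoized


@weak_two_level_cache
def compile_stub(fun, shape):
    # Hypothetical stand-in for an expensive compilation step. The result
    # stores only the function's name, not the function object itself.
    return ("compiled", fun.__name__, shape)


def f():
    return 1.0


first = compile_stub(f, (2, 3))
second = compile_stub(f, (2, 3))  # served from the second-level dict
entries_before = compile_stub.cache_size()

del f          # drop the last strong reference to the callable
gc.collect()   # not strictly needed on CPython, but harmless
entries_after = compile_stub.cache_size()
```

Once the user drops every reference to the function, its whole second-level dictionary goes with it, which is the behavior the commit message asks for.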
Matthew Johnson
98c7567a0d add flag for logging when jit performs compilation 2019-08-23 08:17:41 -07:00
Peter Hawkins
e68feda4ce Optimize linear_util cache lookup.
Simplify code to build a hashable payload:
* Sort the params at construction time.
* Separate stores and transforms into separate fields, to make it easier to hash and test equality of the non-stores.
* Don't build tuples in __eq__(); instead, test fields for equality directly.

Make some small optimizations to the Store implementation: use __slots__, test for object identity rather than relying on hasattr().
2019-08-13 09:49:27 -04:00
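Note: the optimizations listed in the message above can be sketched roughly as follows. The class names CacheKey and Store and their fields are illustrative assumptions, not the actual linear_util.py definitions. The sketch sorts params once at construction time, compares fields directly in __eq__, uses __slots__, and checks a sentinel by identity instead of calling hasattr().

```python
class CacheKey:
    """Hashable payload: params are sorted once, at construction time."""
    __slots__ = ("f", "transforms", "params")

    def __init__(self, f, transforms, params):
        self.f = f
        self.transforms = tuple(transforms)
        self.params = tuple(sorted(params.items()))

    def __hash__(self):
        return hash((self.f, self.transforms, self.params))

    def __eq__(self, other):
        # Compare fields directly rather than building tuples to compare.
        return (isinstance(other, CacheKey)
                and self.f == other.f
                and self.transforms == other.transforms
                and self.params == other.params)


class Store:
    """Write-once cell that tests a sentinel by identity, not via hasattr()."""
    __slots__ = ("_val",)
    _EMPTY = object()  # class-level sentinel marking an unfilled cell

    def __init__(self):
        self._val = Store._EMPTY

    def store(self, val):
        if self._val is not Store._EMPTY:
            raise ValueError("Store is already occupied")
        self._val = val

    @property
    def val(self):
        if self._val is Store._EMPTY:
            raise ValueError("Store is empty")
        return self._val


key_a = CacheKey(abs, ["jvp"], {"b": 2, "a": 1})
key_b = CacheKey(abs, ["jvp"], {"a": 1, "b": 2})
cell = Store()
cell.store(3)
```

Because the params are normalized up front, two keys built from dictionaries with different insertion orders compare and hash equal without any extra work per lookup.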
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk; replace them with a more general, thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
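Note: one common way to make a memoize decorator safe under concurrent calls is to guard the dictionary with a lock. The sketch below is a pure-Python illustration of that idea only. It is not the fastcache-based util.memoize the message refers to, and thread_safe_cache and make_token are hypothetical names.

```python
import threading


def thread_safe_cache(fun):
    """Memoize `fun` on its positional arguments, guarding the dict with a lock."""
    lock = threading.Lock()
    cache = {}

    def wrapper(*args):
        with lock:
            if args in cache:
                return cache[args]
        # Compute outside the lock so slow calls do not block other keys.
        # Two threads may race to compute the same key; the first one stored wins.
        result = fun(*args)
        with lock:
            return cache.setdefault(args, result)

    return wrapper


@thread_safe_cache
def make_token(key):
    # Returns a fresh object per uncached call, so identity reveals cache hits.
    return object()


results = []


def worker():
    results.append(make_token("k"))


threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Every thread ends up with the same cached object, even if several of them computed a candidate value at the same time.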
Peter Hawkins
08013954a4 Use fastcache for LRU caches in JAX.
fastcache is both faster than the previous cache implementation and thread-safe.
2019-07-22 17:24:10 -04:00
Matthew Johnson
2582598294 remove assert that fails on python3 map objects 2019-04-10 22:17:54 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Peter Hawkins
e7f9fbe4f4 Minor typo fixes. 2019-03-12 15:28:07 -04:00
Peter Hawkins
43676de13e Remove unnecessary output tupling. 2019-03-12 15:25:13 -04:00
Peter Hawkins
1800d6554d Add comments to linear_util.py. 2019-03-12 15:07:52 -04:00
Peter Hawkins
2b383bdbd9 Increase cache size to 4096. 2019-01-15 10:32:58 -05:00
Peter Hawkins
3266bb3122 Change linear_util.memoize to use an LRU cache.
Add util.OrderedDict that retrofits a move_to_end method onto Python 2 OrderedDicts.
2019-01-14 21:48:28 -05:00
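Note: an LRU memoizer built on OrderedDict.move_to_end, the method the message above mentions, might look like the sketch below. The decorator name lru_memoize is hypothetical, and this is not the code from the commit. The default of 4096 mirrors the size named in the later "Increase cache size to 4096" entry and is otherwise arbitrary.

```python
from collections import OrderedDict
import functools


def lru_memoize(max_size=4096):
    """Memoize on positional args, evicting the least recently used entry."""
    def decorator(fun):
        cache = OrderedDict()

        @functools.wraps(fun)
        def wrapper(*args):
            if args in cache:
                cache.move_to_end(args)  # mark as most recently used
                return cache[args]
            result = fun(*args)
            cache[args] = result
            if len(cache) > max_size:
                cache.popitem(last=False)  # drop the least recently used entry
            return result

        wrapper.cache = cache  # exposed for inspection in this sketch
        return wrapper
    return decorator


calls = []


@lru_memoize(max_size=2)
def square(x):
    calls.append(x)
    return x * x


square(1)
square(2)
square(1)  # hit: (1,) becomes the most recently used key
square(3)  # miss: the cache is full, so (2,) is evicted
```

After these calls the cache holds the keys (1,) and (3,), and calling square(2) again recomputes it.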
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00