30 Commits

Author SHA1 Message Date
Matthew Johnson
c26c77d2c3 fix a 'store occupied' error in jax_debug_nans
This code snippet could cause a 'store occupied' error:

@jit
def f(x):
  return x + np.nan

FLAGS.jax_debug_nans = True
f(1)

The reason is that in xla._xla_call_impl we would run a
linear_util.WrappedFun twice, first via xla._xla_callable and then again
directly (i.e. in op-by-op) if we got a nan on the output. Things would
work fine if the second execution also raised a nan error, since then
the WrappedFun wouldn't complete execution, but if the second execution
does not raise an error (as in the above case, because `1 + np.nan`
doesn't involve any jax primitive executions) then we'd end up with a
StoreOccupied error from running the WrappedFun twice.

The fix is just to intentionally allow re-running the WrappedFun, since
the whole point of jax_debug_nans is to re-run functions that in normal
circumstances we would only want to execute exactly once.
2020-10-07 12:17:24 -07:00
Jean-Baptiste Lespiau
c88be879d7 Use the Python jit for the compilation in the C++ jit.
Using xla_computation has been a doomed attempt, because it does not support all the features, and cannot deal with, in particular, nested tracing.

Thus, we directly use the current path, and use a thread local value to access the last compiled objects from C++ (it allows to not touch the Python tracing logic).

This also:
- Delay the access of jax_enable_64 to after GoogleInit.

PiperOrigin-RevId: 335910130
2020-10-07 11:15:22 -07:00
Akihiro Nitta
06170da69a
Use raise from 2020-09-30 01:20:00 +09:00
Adam Paszke
3f8aaabbcc Interrupt lu transformation generators whenever an exception occurs
This fixes some errors that have been appearing in our CI from time to
time. All transformations are implemented as generators, but they
haven't been explicitly aborted when an exception has been raised.
Instead, they only got closed when they got garbage collected, which
could happen at an unspecified later time, potentially leading to a
corruption of global state, which could have been modified after the
exception was handled.

Note that this implementation doesn't propagate the original exception
into the argument transformations, and doesn't allow them to handle the
error either. Such an extension would be possible, but throwing an
exception into a generator mutates the exception object, clobbering
the nice traceback that we would usually carry. One can work around
those issues, but it feels really hacky and we don't need it right now
anyway, so I figured we'll be better off with the simple thing for the
time being.
2020-09-09 20:43:05 +02:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
6193e5e4dc revamp custom_jvp/vjp implementation to fix bugs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2020-03-29 19:35:01 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
George Necula
282225f676
Added some pytype annotations (#2386)
Tried to catch all uses of linear_util.WrappedFun
2020-03-09 20:41:01 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
George Necula
528a69f32e Added some more documentation to the linear_util module
Also cleaned up the inconsistent way of importing the module.
Prefer importing with qualified name 'lu.transformation' rather
than just 'transformation'.
2020-01-05 16:40:26 +01:00
Skye Wanderman-Milne
891aecb941
Add test utilities for counting compilations. (#1895)
Also uses the new utilities to check that pmap doesn't compile constant computations.
2019-12-19 11:19:58 -08:00
Matthew Johnson
9a8523603c Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-27 19:52:24 -08:00
Matthew Johnson
2867e4be08 fix grad of jit caching bug
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-11-26 09:04:35 -08:00
Matthew Johnson
d2156ea1c5 improve names, avoid double lookup (thanks @hawkinsp) 2019-10-31 16:26:29 -07:00
Matthew Johnson
8bcee8d45f fix a leak where compiled results lived too long
The original repro @levskaya showed us was essentially this OOM:

  for i in range(40):
    f = jit(lambda: 1. * np.ones((300, 1024, 1024)))
    f().block_until_ready()

Even though f was being rebound on every iteration, the cache entries
corresponding to the previous iterations of the loop were sticking around.

Instead, if the user drops all references to a function, we want to clear the
corresponding compilation cache entries (since they can never be used).

The fix here is to use a two-level cache for compiled code: the first level is
a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun,
and the second level is a regular dictionary keyed by (transforms, params,
args). Because this logic is now present in linear_util.py:cache, the
implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be superfluous
now.

One unintended consequence is that this implementation now avoids using
fastcache.crlu_cache for the jit and pmap compilation caches. It was easier to
implement this logic in pure Python. We might want to revise this for
performance reasons.

This commit also incidentally fixed #1600.
2019-10-31 16:26:29 -07:00
Matthew Johnson
98c7567a0d add flag for logging when jit performs compilation 2019-08-23 08:17:41 -07:00
Peter Hawkins
e68feda4ce Optimize linear_util cache lookup.
Simplify code to build a hashable payload:
* Sort the params at construction time.
* Separate stores and transforms into separate fields, to make it easier to hash and test equality of the non-stores.
* Don't build tuples in __eq__(), instead just test fields for equality directly.

Make some small optimizations to the Store implementation: use __slots__, test for object identity rather than relying on hasattr().
2019-08-13 09:49:27 -04:00
Peter Hawkins
a8ddf071bd Add test case for concurrent device_get and device_put calls.
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
2019-08-09 13:12:44 -04:00
Peter Hawkins
08013954a4 Use fastcache for LRU caches in JAX.
fastcache is both a faster cache implementation and is also thread-safe.
2019-07-22 17:24:10 -04:00
Matthew Johnson
2582598294 remove assert that fails on python3 map objects 2019-04-10 22:17:54 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Peter Hawkins
e7f9fbe4f4 Minor typo fixes. 2019-03-12 15:28:07 -04:00
Peter Hawkins
43676de13e Remove unnecessary output tupling. 2019-03-12 15:25:13 -04:00
Peter Hawkins
1800d6554d Add comments to linear_util.py. 2019-03-12 15:07:52 -04:00
Peter Hawkins
2b383bdbd9 Increase cache size to 4096. 2019-01-15 10:32:58 -05:00
Peter Hawkins
3266bb3122 Change linear_util.memoize to use an LRU cache.
Add util.OrderedDict that retrofits a move_to_end method onto Python 2 OrderedDicts.
2019-01-14 21:48:28 -05:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00