This code snippet could cause a 'store occupied' error:
@jit
def f(x):
  return x + np.nan
FLAGS.jax_debug_nans = True
f(1)
The reason is that in xla._xla_call_impl we would run a
linear_util.WrappedFun twice, first via xla._xla_callable and then again
directly (i.e. in op-by-op) if we got a nan on the output. Things would
work fine if the second execution also raised a nan error, since then
the WrappedFun wouldn't complete execution, but if the second execution
does not raise an error (as in the above case, because `1 + np.nan`
doesn't involve any jax primitive executions) then we'd end up with a
StoreOccupied error from running the WrappedFun twice.
The fix is just to intentionally allow re-running the WrappedFun, since
the whole point of jax_debug_nans is to re-run functions that in normal
circumstances we would only want to execute exactly once.
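To illustrate, here is a minimal, hypothetical sketch of a write-once store in the spirit of linear_util.Store, extended with a reset that permits the deliberate re-run (the names here are illustrative, not the real API):

```python
class StoreException(Exception):
  pass

class Store:
  """A write-once value holder, like linear_util.Store."""
  __slots__ = ("_val",)
  _EMPTY = object()  # sentinel marking an empty store

  def __init__(self):
    self._val = Store._EMPTY

  def store(self, val):
    # A second write without a reset is the "Store occupied" error.
    if self._val is not Store._EMPTY:
      raise StoreException("Store occupied")
    self._val = val

  def reset(self):
    # Intentionally allow re-running a WrappedFun: clearing the store
    # before a second execution avoids the "Store occupied" error.
    self._val = Store._EMPTY

store = Store()
store.store(1)
store.reset()   # permit the debug re-run
store.store(2)  # would raise StoreException without the reset
```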
Using xla_computation was a doomed attempt: it does not support all the features and, in particular, cannot handle nested tracing.
Thus, we directly use the current path, and use a thread-local value to access the last compiled objects from C++ (which avoids touching the Python tracing logic).
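A minimal sketch of the thread-local mechanism, with hypothetical names (record_compiled and last_compiled are illustrations, not the real API):

```python
import threading

# Stash the most recently compiled executable in a thread-local so a
# caller (e.g. C++) can retrieve it afterwards, without changing the
# Python tracing logic and without races between threads.
_thread_local_state = threading.local()

def record_compiled(executable):
  _thread_local_state.last_compiled = executable

def last_compiled():
  return getattr(_thread_local_state, "last_compiled", None)

record_compiled("xla_executable_0")  # placeholder for a compiled object
```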
This also:
- Delays access to jax_enable_64 until after GoogleInit.
PiperOrigin-RevId: 335910130
This fixes some errors that have been appearing in our CI from time to
time. All transformations are implemented as generators, but they
haven't been explicitly aborted when an exception has been raised.
Instead, they only got closed when they got garbage collected, which
could happen at an unspecified later time, potentially leading to a
corruption of global state, which could have been modified after the
exception was handled.
Note that this implementation doesn't propagate the original exception
into the argument transformations, and doesn't allow them to handle the
error either. Such an extension would be possible, but throwing an
exception into a generator mutates the exception object, clobbering
the nice traceback that we would usually carry. One can work around
those issues, but it feels really hacky and we don't need it right now
anyway, so I figured we'll be better off with the simple thing for the
time being.
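The fix can be sketched as follows: the caller explicitly closes the transformation generator when the wrapped call raises, instead of leaving cleanup to garbage collection. All names here are illustrative:

```python
events = []

def logging_transformation(x):
  # A transformation written as a generator: code before the yield runs
  # on the way in, code after the yield runs on the way out.
  events.append("enter")
  _ = yield x + 1
  events.append("exit")

def call_transformed(gen_fun, f, x):
  gen = gen_fun(x)
  transformed_x = next(gen)
  try:
    ans = f(transformed_x)
  except Exception:
    # Abort the generator immediately rather than leaving it to be
    # closed at some unspecified later time by the garbage collector.
    gen.close()
    raise
  try:
    gen.send(ans)
  except StopIteration:
    pass
  return ans
```

On the error path the generator's post-yield code never runs, so it cannot corrupt global state that was modified after the exception was handled.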
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory by avoiding instantiating large
constants at trace time, and they cause less memory fragmentation as
well. It also simplifies several internals.
See https://github.com/google/jax/pull/3370 for more information.
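A toy, pure-Python sketch of the difference (hypothetical names; not the actual tracing machinery): with dynamic-scope staging, a primitive is recorded whenever a trace is active, so large constants stay symbolic instead of being instantiated at trace time.

```python
trace_stack = []  # stack of equation lists; non-empty means we are tracing

def ones(n):
  if trace_stack:
    # Dynamic-scope staging: any primitive called while a trace is
    # active is recorded, regardless of data dependence.
    trace_stack[-1].append(("ones", n))
    return f"var{len(trace_stack[-1])}"  # symbolic value, no big array
  return [1.0] * n  # op-by-op: instantiate the constant now

def trace(f):
  trace_stack.append([])
  out = f()
  return trace_stack.pop(), out

eqns, out = trace(lambda: ones(5))
```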
Also cleaned up the inconsistent way of importing the module.
Prefer importing with qualified name 'lu.transformation' rather
than just 'transformation'.
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
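A toy sketch of the two policies for a single sin call (illustrative only, not JAX's actual remat API): standard reverse mode saves the residual, while the rematerializing version saves only the input and recomputes the residual on the backward pass.

```python
import math

def forward_with_residuals(x):
  # Standard reverse mode: save the intermediate residual for backward.
  y = math.sin(x)
  residual = math.cos(x)  # stored in memory until the backward pass
  return y, residual

def backward_from_residual(residual, g):
  return g * residual

def forward_rematerialized(x):
  # Remat policy: save only the input; no residuals are stored.
  return math.sin(x), x

def backward_rematerialized(saved_x, g):
  # Recompute the residual from the saved input, trading FLOPs for
  # memory -- attractive when memory access costs more than compute.
  return g * math.cos(saved_x)
```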
The original repro @levskaya showed us was essentially this OOM:
for i in range(40):
  f = jit(lambda: 1. * np.ones((300, 1024, 1024)))
  f().block_until_ready()
Even though f was being rebound on every iteration, the cache entries
corresponding to the previous iterations of the loop were sticking around.
Instead, if the user drops all references to a function, we want to clear the
corresponding compilation cache entries (since they can never be used).
The fix here is to use a two-level cache for compiled code: the first level is
a WeakKeyDictionary keyed by the raw Python callable underlying the WrappedFun,
and the second level is a regular dictionary keyed by (transforms, params,
args). Because this logic is now present in linear_util.py:cache, the
implementations of WrappedFun.__eq__ and WrappedFun.__hash__ may be superfluous
now.
One unintended consequence is that this implementation no longer uses
fastcache.clru_cache for the jit and pmap compilation caches. It was easier to
implement this logic in pure Python. We might want to revisit this for
performance reasons.
This commit also incidentally fixed #1600.
Simplify code to build a hashable payload:
* Sort the params at construction time.
* Separate stores and transforms into separate fields, to make it easier to hash and test equality of the non-stores.
* Don't build tuples in __eq__(), instead just test fields for equality directly.
Make some small optimizations to the Store implementation: use __slots__, test for object identity rather than relying on hasattr().
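A hypothetical sketch of the payload described above: params sorted at construction, stores held in a separate field and excluded from equality and hashing, and field-by-field comparison without intermediate tuples.

```python
class WrappedFun:
  __slots__ = ("f", "transforms", "stores", "params")

  def __init__(self, f, transforms, stores, params):
    self.f = f
    self.transforms = tuple(transforms)
    self.stores = tuple(stores)                  # excluded from eq/hash
    self.params = tuple(sorted(params.items()))  # sorted once, up front

  def __eq__(self, other):
    # Compare fields directly instead of building throwaway tuples.
    return (self.f == other.f and
            self.transforms == other.transforms and
            self.params == other.params)

  def __hash__(self):
    return hash((self.f, self.transforms, self.params))
```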
Fix concurrency problems in memoize_... decorators.
Rename util.memoize to util.cache.
Remove util.memoize_unary and xla_bridge.memoize_thunk, replace with more general and thread-safe util.memoize that wraps fastcache.
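A sketch of what a general, thread-safe cache decorator might look like; fastcache may not be available everywhere, so functools.lru_cache (which is also safe to call from multiple threads) stands in here:

```python
import functools

calls = []  # records actual executions, to demonstrate caching

def cache(max_size=4096):
  # Thread-safe memoization decorator. The real change wraps fastcache;
  # functools.lru_cache stands in with equivalent behavior.
  def decorator(fun):
    return functools.lru_cache(maxsize=max_size)(fun)
  return decorator

@cache()
def expensive(x):
  calls.append(x)  # only runs on a cache miss
  return x * x
```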