The `config` object it re-exports is available on the top-level `jax` package,
i.e.
from jax.config import config
can always safely be replaced with
from jax import config
or just
import jax
jax.config
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if, set, the C++ jit code calls with the inputs (function, args, kwargs, outputs). Use the posthook feature to implement NaN checking.
Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.
PiperOrigin-RevId: 365108787
To keep C++ and Python state in synchronization, adds new update_..._hook callbacks to the boolean configuration objects that are called when global or thread-local state changes.
Without this when running tests the flag parsing may happen before the
host_callback module is loaded and then the host_callback flags may
be left undefined.
Create separate holder objects for global and thread-local state, and move enable_x64 and disable_jit context into the holder objects.
Expose the global and per-thread state objects to Python via pybind11.
Refactoring only; no functional changes intended.
PiperOrigin-RevId: 363510449
This is done to simplify the code, and not at all for performance, because it's only executed during the compilation phase.
One possible design question: should we let the user access the value of the flag if it has not been set? Right now, the Python code allows it I think (meaning the behavior may not match the flag value, which has not been parsed yet).
We could raise an error, by setting the flag value to absl::nullopt, and check it's not null. But it would be a breaking change, so I am a little reluctant doing so.
PiperOrigin-RevId: 360407549
This change adds to the error message when we hit an escaped tracer. In
particular, it adds source info for the function that was transformed.
This change currently only applies to escaped `DynamicJaxprTracer`s
(arising from `jit`, `pmap`, `scan`, and other staging functions) and
not other traces. A natural follow-up would be to attach this
information to other traces.
Co-authored-by: Lena Martens <lenamartens@google.com>
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.
* use numpy.random to select test cases, rather than random. This allows more control over random seeds. Pick a fixed random seed for each test case.
* sort types in linalg_test.py so the choice of test cases is deterministic.
* use known_flags=True when doing early parsing of flags from parse_flags_with_absl.