41 Commits

Author SHA1 Message Date
Sergei Lebedev
c90f1f0c96 Deprecated accessing config via the jax.config submodule
The `config` object it re-exports is available on the top-level `jax` package,
i.e.

    from jax.config import config

can always safely be replaced with

    from jax import config

or just

    import jax
    jax.config
2023-10-27 20:08:10 +01:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Matthew Johnson
399f330f5a disable_omnistaging error, enable_omnistaging warn 2021-03-29 18:11:40 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Jake VanderPlas
40dac9425c pre-release omnistaging cleanup 2021-03-25 16:44:58 -07:00
Peter Hawkins
cac1b891ce [JAX] Refactor NaN/Inf checking in jitted functions.
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if, set, the C++ jit code calls with the inputs (function, args, kwargs, outputs). Use the posthook feature to implement NaN checking.

Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.

PiperOrigin-RevId: 365108787
2021-03-25 13:13:02 -07:00
Peter Hawkins
b1e1d0acd6 Switch disable_jit() to use the common boolean state mechanism. 2021-03-24 19:46:52 -04:00
Matthew Johnson
89768a3d28 add jax_default_matmul_precision flag & context mngr 2021-03-24 14:03:58 -07:00
Peter Hawkins
1d2c7b87a8 Change enable_x64 to use the common boolean configuration mechanism.
To keep C++ and Python state in synchronization, adds new update_..._hook callbacks to the boolean configuration objects that are called when global or thread-local state changes.
2021-03-24 14:45:53 -04:00
Matthew Johnson
70cd62d545 disable disable_omnistaging 2021-03-23 19:13:15 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
George Necula
52408b35a8 Move the flag definitions to config
Without this when running tests the flag parsing may happen before the
host_callback module is loaded and then the host_callback flags may
be left undefined.
2021-03-18 11:09:33 +01:00
Peter Hawkins
23756a040b [JAX] Refactor handling of JIT interpreter state in jax_jit API.
Create separate holder objects for global and thread-local state, and move enable_x64 and disable_jit context into the holder objects.

Expose the global and per-thread state objects to Python via pybind11.

Refactoring only; no functional changes intended.

PiperOrigin-RevId: 363510449
2021-03-17 14:39:34 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Jean-Baptiste Lespiau
654a5b332c Remove a callback to Python to get the value of some flag.
This is done to simplify the code, and not at all for performance, because it's only executed during the compilation phase.

One possible design question: should we let the user access the value of the flag if it has not been set? Right now, the Python code allows it I think (meaning the behavior may not match the flag value, which has not been parsed yet).
We could raise an error, by setting the flag value to absl::nullopt, and check it's not null. But it would be a breaking change, so I am a little reluctant doing so.

PiperOrigin-RevId: 360407549
2021-03-02 05:39:52 -08:00
Jean-Baptiste Lespiau
5d11d101c6 Move the x64 context manager threadlocal state from jax/python to xla/c++
Fixes #5532.

PiperOrigin-RevId: 360252057
2021-03-01 12:33:11 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Matthew Johnson
a7bfebe4bc improve leak checker flag description 2021-01-23 14:17:22 -08:00
Matthew Johnson
203af4517b revive the leak checker, as a debug mode
Co-authored-by: James Bradbury <jekbradbury@google.com>
2021-01-22 18:31:00 -08:00
Matthew Johnson
84e91d5f1d add transformed fun src info to escaped tracer err
This change adds to the error message when we hit an escaped tracer. In
particular, it adds source info for the function that was transformed.

This change currently only applies to escaped `DynamicJaxprTracer`s
(arising from `jit`, `pmap`, `scan`, and other staging functions) and
not other traces. A natural follow-up would be to attach this
information to other traces.

Co-authored-by: Lena Martens <lenamartens@google.com>
2021-01-20 15:30:37 -08:00
Peter Hawkins
e4e1af5d91 Reenable pytype for previously problematic ABSL library.
The problematic change to ABSL was rolled back.

PiperOrigin-RevId: 343421000
2020-11-19 20:13:16 -08:00
Peter Hawkins
c8e6a6a061 Internal change
PiperOrigin-RevId: 343407824
2020-11-19 18:05:03 -08:00
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Tom Hennigan
f0fb7d0925
Use omnistaging env var even when not using absl flags for config. (#4152) 2020-08-26 14:06:27 -07:00
Matthew Johnson
822344d654
add placeholder disable_omnistaging method (#4054) 2020-08-13 15:25:39 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
b3806ce874
add placeholder enable_omnistaging method (#3907) 2020-07-30 00:42:32 -07:00
Matthew Johnson
c8771e12e0
add omnistaging flag placeholder (#3904) 2020-07-29 21:44:03 -07:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Jamie Townsend
670fab59cf
Test code in docs and api.py docstrings (#2994)
Also remove jaxpr doc tests from api_test.py.
2020-05-16 16:19:24 +03:00
George Necula
2e9047d388
Add flag to enable checking, and turn on checking in tests. (#2900)
Fix an error in check_jaxpr.
2020-05-01 09:16:31 +03:00
George Necula
b79c7948ee Removed dependency on distutils.strtobool 2020-02-06 17:27:46 +01:00
James Bradbury
d174fd660c fix config.update when absl is enabled 2019-10-02 16:08:05 -07:00
Peter Hawkins
ca0d943999 Test case improvements:
* use numpy.random to select test cases, rather than random. This allows more control over random seeds. Pick a fixed random seed for each test case.
* sort types in linalg_test.py so the choice of test cases is deterministic.
* use known_flags=True when doing early parsing of flags from parse_flags_with_absl.
2019-04-12 10:48:11 -04:00
Anselm Levskaya
116e329e10 correctly update jax config.values after absl flag parsing 2019-04-04 02:09:35 -07:00
Matthew Johnson
a285017110 fix failing tests (misc small bugs) 2018-12-13 11:52:41 -08:00
Dougal Maclaurin
8b88027df0 Number of test cases settable with command-line flag 2018-12-06 18:30:59 -05:00
Dougal Maclaurin
2df36f7510 Made a shim to handle configuration without having absl parse command-line flags.
PiperOrigin-RevId: 223391288
2018-11-29 13:44:54 -08:00