209 Commits

Author SHA1 Message Date
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Selam Waktola
b02f82b815 redundant phrase 'ever time' removed 2024-04-22 11:47:23 -07:00
Sai-Suraj-27
60cd5af67a Made the error messages when raising TypeError better. 2024-04-10 09:27:47 +05:30
Sai-Suraj-27
5564521308 Prefer raising of TypeError for invalid types instead of ValueError. 2024-04-08 13:08:24 +05:30
Matthew Johnson
3d4687fbfc add a temporary config option to disable custom_vjp shape checking 2024-04-04 18:21:10 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
jax authors
d8f231a85e Merge pull request #20250 from jakevdp:key-reuse-jit
PiperOrigin-RevId: 616971171
2024-03-18 15:58:11 -07:00
Peter Hawkins
ee2631e4da Remove --jax_parallel_functions_output_gda.
PiperOrigin-RevId: 616898032
2024-03-18 11:46:06 -07:00
Jake VanderPlas
ae4e273b74 Add key reuse config to trace context 2024-03-14 06:59:37 -07:00
Peter Hawkins
cf856ad4a9 Reverts 8e2a8b7b95e838947dcf581d146909d5c4128742
PiperOrigin-RevId: 615401711
2024-03-13 07:01:49 -07:00
Matthew Johnson
dd0ce6e2ff add upgrade (aka to-be-removed) flag for new select rule 2024-03-01 09:50:01 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
Peter Hawkins
d713f3a632 Fix test flake which occurred due to spurious cache misses.
The StateContextManager restores its thread-local state to None, which means that the
initial thread-local state must also be None if the context manager is
to correctly restore the initial state.

This caused a test failure in a test case in pmap_test which checked for
exactly one cache entry across threads. One thread had used the
softmax_custom_jvp context manager, and had a different state (None)
instead of False.
2024-02-20 22:54:31 +00:00
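The failure mode described in d713f3a632 can be illustrated with a minimal sketch. `ToyState` and `_Scope` are hypothetical stand-ins, not JAX's `StateContextManager`. The sketch assumes, as the message states, that leaving the context resets the thread-local value to None rather than to the value seen on entry.

```python
import threading


class ToyState:
    """Hypothetical stand-in for a thread-local config state (not JAX's class)."""

    def __init__(self, initial):
        self._local = threading.local()
        self._initial = initial

    @property
    def value(self):
        # Threads that never set a value see the initial state.
        return getattr(self._local, "value", self._initial)

    def __call__(self, new_value):
        return _Scope(self, new_value)


class _Scope:
    def __init__(self, state, new_value):
        self._state = state
        self._new_value = new_value

    def __enter__(self):
        self._state._local.value = self._new_value

    def __exit__(self, *exc):
        # Mirrors the behavior in the commit message: exit resets to None,
        # not to whatever the value was before entering.
        self._state._local.value = None
        return False


def state_after_use(initial):
    """Return (value before, value after) one use of the context manager."""
    state = ToyState(initial)
    before = state.value
    with state(True):
        pass
    return before, state.value


print(state_after_use(False))  # initial False: the thread ends in a different state
print(state_after_use(None))   # initial None: the thread ends where it started
```

With an initial value of False, a thread that has used the context manager ends up with None while an untouched thread still reports False, so anything keyed on that state sees two distinct values.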
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Peter Hawkins
2165611584 Fix code to populate defaults for boolean flags from environment variables.
PiperOrigin-RevId: 608574620
2024-02-20 05:57:23 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835 Reverts 318a19a89387caebd116168c4e47592e7d71ca65
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893 Removed deprecated jax.config methods
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
Peter Hawkins
86c8af7025 Make config.value_holders private to avoid misuse.
Re-add a config.values property to fix breakage to end users caused by its removal.

PiperOrigin-RevId: 606429238
2024-02-12 18:00:18 -08:00
Peter Hawkins
995bdaa912 Small optimization to trace_context.
Before:
```
In [2]: %timeit jax._src.config.config._trace_context()
The slowest run took 23.63 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 3.5 µs per loop
```

After:
```
In [5]: %timeit jax._src.config.trace_context()
The slowest run took 12.16 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 2.59 µs per loop
```

It's slightly faster to access flags directly via the holder object, rather than via `jax.config`.

PiperOrigin-RevId: 606366377
2024-02-12 14:30:11 -08:00
Peter Hawkins
c1f234a95c Refactor and optimize the implementation of config options.
Previously accessing the value of a configuration option required accessing a `values` dictionary in jax._src.config.config. By moving the values into the FlagHolder and StateContextManager objects, we allow for accessing the values directly without as many dictionary accesses.

Timings before, note `enable_x64` is a `StateContextManager` value and `jax_pprint_use_color` is a `FlagHolder` value:

```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
1000000 loops, best of 5: 328 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 377 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 293 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
1000000 loops, best of 5: 358 ns per loop
```

Timings after:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
10000000 loops, best of 5: 175 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 220 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 316 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
10000000 loops, best of 5: 54.9 ns per loop
```

i.e., if accessed via the holder object directly, this change is a significant speedup, and if accessed by the `jax.config` dictionary this change is a good speedup for `StateContextManager` values and a small slowdown for `FlagHolder` values.

In a subsequent change, we should do more work to avoid using `jax.config.xyz` to access flag values inside jax.

I also note that one of the reasons `StateContextManager` values are a bit suboptimal is that they still require a dictionary lookup in a `threading.local` object. We can probably avoid that with a small C++ extension.

PiperOrigin-RevId: 606340656
2024-02-12 13:04:38 -08:00
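The idea in c1f234a95c of storing each value on its holder instead of in a shared dictionary can be sketched roughly as follows. `DictConfig`, `Holder`, and the flag name `toy_use_color` are hypothetical and only illustrate the two access paths; they are not JAX's `FlagHolder` or `StateContextManager`, and the timings will differ from those quoted in the message.

```python
import timeit


class DictConfig:
    """Hypothetical registry that keeps every value in one shared dict."""

    def __init__(self):
        self.values = {}

    def read(self, name):
        # One method call plus one dictionary lookup per access.
        return self.values[name]


class Holder:
    """Hypothetical holder that stores its own value as an attribute."""

    __slots__ = ("name", "value")

    def __init__(self, name, value):
        self.name = name
        self.value = value


config = DictConfig()
config.values["toy_use_color"] = True
use_color = Holder("toy_use_color", True)

# Compare the two access paths; absolute numbers depend on the machine.
via_dict = timeit.timeit(lambda: config.read("toy_use_color"), number=100_000)
via_holder = timeit.timeit(lambda: use_color.value, number=100_000)
print(f"dict lookup: {via_dict:.4f}s, holder attribute: {via_holder:.4f}s")
```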
Peter Hawkins
6b8c99c60e Tighten up typing of optional config values.
Distinguish optional config values from non-optional config values, and accurately type optional config values.

PiperOrigin-RevId: 606235519
2024-02-12 06:44:02 -08:00
jax authors
9a098e922a Share autotune config between hosts.
PiperOrigin-RevId: 604569298
2024-02-06 01:28:18 -08:00
George Necula
fdf227e7b2 [export] Set default native serialization version to 9.
This version adds better support for JAX effects.

See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.

PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
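The shape change described in fc6df3218c can be written down as a small pure-Python sketch. `shard_shape` is a hypothetical helper that only does the shape arithmetic from the message; it does not call `pmap` or touch any devices.

```python
def shard_shape(global_shape, num_devices, rank_reduction=True):
    """Shape of one shard when axis 0 is split across `num_devices`.

    Assumes, as in the example from the commit message, that axis 0 has
    exactly one entry per device.
    """
    if global_shape[0] != num_devices:
        raise ValueError("axis 0 must equal the number of devices")
    rest = tuple(global_shape[1:])
    # "Unstacked" axis: the sharded axis disappears from each shard.
    # "Chunked" axis: the sharded axis is kept with size 1.
    return rest if rank_reduction else (1,) + rest


print(shard_shape((8, 100), 8, rank_reduction=True))   # (100,)
print(shard_shape((8, 100), 8, rank_reduction=False))  # (1, 100)
```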
jax authors
ab8eb896d7 Document several jax.config methods.
PiperOrigin-RevId: 598775913
2024-01-16 02:26:41 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
jax authors
adbbe69cc2 Add option to share compiled module between hosts.
PiperOrigin-RevId: 597754861
2024-01-11 23:38:02 -08:00
jax authors
ea66029731 Introduce min entry size check for compilation cache.
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is at least the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.

Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.

Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
2024-01-04 15:17:05 -08:00
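The combined check from ea66029731 might look like the sketch below. `should_cache` and its parameter names are hypothetical. The sketch assumes an entry is written only when compilation takes at least as long as the time threshold, and, per the message, only when the serialized size is not below the size threshold.

```python
def should_cache(compile_time_secs, entry_size_bytes,
                 min_compile_time_secs, min_entry_size_bytes):
    """Return True only if both thresholds are met (hypothetical helper)."""
    slow_enough = compile_time_secs >= min_compile_time_secs
    big_enough = entry_size_bytes >= min_entry_size_bytes
    return slow_enough and big_enough


# Slow compile, tiny executable: skipped by the new size check.
print(should_cache(5.0, 128, min_compile_time_secs=1.0, min_entry_size_bytes=1024))
# Slow compile, large executable: cached.
print(should_cache(5.0, 4096, min_compile_time_secs=1.0, min_entry_size_bytes=1024))
```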
jax authors
981b670453 [Jax] Allow setting the Python traceback frame limit.
PiperOrigin-RevId: 595607107
2024-01-03 23:26:55 -08:00
Matthew Johnson
9112dcebc9 add jax.explain_cache_misses tracing cache miss explanations
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information.

The main idea here is that pointing out which parts of the cache key differ
from previously-seen keys can constitute a pretty good explanation.

This PR adds an explanation mechanism. It can be enabled in a few different ways:
  * setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
  * setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
  * using the `jax._src.config.explain_cache_misses` context
    manager (not in public namespace yet);
  * when parsing command line flags with absl, using the
    `--jax_explain_cache_misses` flag.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-12-26 21:54:27 -08:00
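The "diff the cache key against previously seen keys" idea from 9112dcebc9 can be sketched in plain Python. `explain_miss` and the dict-shaped keys are hypothetical; JAX's real cache keys and explanation messages are structured differently.

```python
def explain_miss(new_key, seen_keys):
    """Describe how `new_key` differs from the closest previously seen key.

    Keys are plain dicts here; this is an illustration, not JAX's cache key.
    """
    if not seen_keys:
        return ["no previous keys: first call"]
    if new_key in seen_keys:
        return []  # cache hit, nothing to explain

    def differing(old):
        fields = sorted(set(old) | set(new_key))
        return [f for f in fields if old.get(f) != new_key.get(f)]

    # Explain against the seen key that differs in the fewest fields.
    closest = min(seen_keys, key=lambda old: len(differing(old)))
    return [f"{f}: {closest.get(f)!r} -> {new_key.get(f)!r}"
            for f in differing(closest)]


seen = [{"fun": "f", "dtype": "float32", "shape": (8,)}]
print(explain_miss({"fun": "f", "dtype": "float64", "shape": (8,)}, seen))
```

In this toy case the only reported difference is the dtype, which is the kind of hint the message describes as usually being enough to know what to fix.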
jax authors
259c285b10 [Jax] Enable jax_include_full_tracebacks_in_locations by default
PiperOrigin-RevId: 591783126
2023-12-17 21:56:13 -08:00
jax authors
196c97fa0c Merge pull request #18949 from froystig:seed-offset
PiperOrigin-RevId: 590637382
2023-12-13 10:18:40 -08:00
Roy Frostig
671790730e introduce a config flag to control a random seed offset 2023-12-12 18:31:07 -08:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform if the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
jax authors
b9b5410ddd Default-enable the Jax persistent compilation cache.
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.

Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.

Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
2023-11-27 14:53:20 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Yash Katariya
cf3c041366 Disable jax memories flag.
PiperOrigin-RevId: 580961421
2023-11-09 10:54:02 -08:00
jax authors
28e33ca5d1 Switch to the new cache-key generation algorithm.
The new cache-key generation algorithm is more robust and
results in fewer stale entries being returned.

Testing: test workloads.
PiperOrigin-RevId: 579928158
2023-11-06 12:57:01 -08:00
Roy Frostig
b22e75716f add threefry_partitionable config setting to thread-local JIT context 2023-10-31 13:45:49 -07:00
Sergei Lebedev
fd3a8b2cc6 Deprecated define_* and DEFINE_* methods on jax.config
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
2023-10-29 20:58:19 +00:00
jax authors
9ba305cced Invalidate in-memory caches on XLA-AutoFDO profile version change.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.

Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
  the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
2023-10-27 12:52:57 -07:00
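The invalidation rule in 9ba305cced can be sketched as a cache that drops its entries whenever a version number changes. `VersionedCache` and `fake_compile` are hypothetical names; how JAX actually wires `--jax_xla_profile_version` into its tracing and compilation caches is not shown here.

```python
class VersionedCache:
    """Hypothetical in-memory cache that is dropped when a version changes."""

    def __init__(self, version=0):
        self._version = version
        self._entries = {}

    def set_version(self, version):
        if version != self._version:
            # Entries compiled under the old profile are no longer valid.
            self._entries.clear()
            self._version = version

    def get_or_compile(self, key, compile_fn):
        if key not in self._entries:
            self._entries[key] = compile_fn(key)
        return self._entries[key]


compiles = []


def fake_compile(key):
    # Records each compilation so repeated work is visible.
    compiles.append(key)
    return f"executable({key})"


cache = VersionedCache(version=1)
cache.get_or_compile("f", fake_compile)
cache.get_or_compile("f", fake_compile)  # hit, no recompilation
cache.set_version(2)
cache.get_or_compile("f", fake_compile)  # recompiled after invalidation
print(len(compiles))
```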
Sergei Lebedev
f2ce5dbd01 MAINT Do not use str() and repr() in f-string replacement fields
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
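The equivalence that f2ce5dbd01 relies on can be checked with a short example. The variable names and the flag string are only for illustration.

```python
name = "jax_enable_x64"

# Explicit str()/repr() calls inside replacement fields are redundant...
redundant = f"flag {str(name)} has repr {repr(name)}"
# ...because plain fields already format like str(), and !r applies repr().
preferred = f"flag {name} has repr {name!r}"

print(preferred)
```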
George Necula
e89212c81a [export] Set the default export serialization version to 8.
This version has been supported by XlaCallModule since July 21, 2023 and we are now past the forward-compatibility window.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

Reverts ae81ac9cc21696a22b973b1eae6ce222c7318ba7

PiperOrigin-RevId: 575382324
2023-10-20 21:00:55 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00