On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:
`jax.config.update('jax_threefry_gpu_kernel_lowering', True)`
PiperOrigin-RevId: 629763763
The StateContextManager restores its thread-local state to None, which means that the
initial thread-local state must also be None if the context manager is
to correctly restore the initial state.
This caused a test failure in a test case in pmap_test which checked for
exactly one cache entry across threads. One thread had used the
softmax_custom_jvp context manager, and had a different state (None)
instead of False.
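A minimal sketch of the failure mode, assuming a simplified context manager. `StateSketch`, `override`, and `value_after_context` are hypothetical names for illustration and are not JAX's actual implementation:

```python
import contextlib
import threading


class StateSketch:
    """Hypothetical stand-in for a thread-local config state."""

    def __init__(self, initial):
        self._local = threading.local()
        self._initial = initial

    @property
    def value(self):
        # Threads that never entered the context see the initial value.
        return getattr(self._local, "value", self._initial)

    @contextlib.contextmanager
    def override(self, new_value):
        self._local.value = new_value
        try:
            yield
        finally:
            # Mirrors the behavior described above: on exit the state is
            # reset to None rather than to the previous value.
            self._local.value = None


def value_after_context(initial):
    """Return the state observed after entering and leaving the context once."""
    state = StateSketch(initial)
    with state.override(True):
        pass
    return state.value
```

With an initial value of `False`, `value_after_context(False)` returns `None`, so a thread that has used the context manager ends up in a different state than one that has not. With an initial value of `None` the two agree.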
Before:
```
In [2]: %timeit jax._src.config.config._trace_context()
The slowest run took 23.63 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 3.5 µs per loop
```
After:
```
In [5]: %timeit jax._src.config.trace_context()
The slowest run took 12.16 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 2.59 µs per loop
```
It's slightly faster to access flags directly via the holder object, rather than via `jax.config`.
PiperOrigin-RevId: 606366377
Previously, accessing the value of a configuration option required accessing a `values` dictionary in `jax._src.config.config`. Moving the values into the `FlagHolder` and `StateContextManager` objects allows accessing them directly, with fewer dictionary accesses.
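The difference between the two access paths can be sketched roughly as follows. `ConfigSketch` and `FlagHolderSketch` are hypothetical classes for illustration only; they are not the real JAX classes and omit everything except where the value is stored:

```python
class ConfigSketch:
    """Hypothetical registry that keeps every value in one dictionary."""

    def __init__(self):
        self.values = {}

    def read(self, name):
        return self.values[name]


class FlagHolderSketch:
    """Hypothetical holder that stores its own value."""

    def __init__(self, name, default):
        self.name = name
        self.value = default  # plain attribute, no registry lookup needed


config = ConfigSketch()
config.values["jax_pprint_use_color"] = True
use_color = FlagHolderSketch("jax_pprint_use_color", True)

via_registry = config.read("jax_pprint_use_color")  # method call plus dict lookup
via_holder = use_color.value                        # single attribute access
```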
Timings before (note that `enable_x64` is a `StateContextManager` value and `jax_pprint_use_color` is a `FlagHolder` value):
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
1000000 loops, best of 5: 328 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 377 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 293 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
1000000 loops, best of 5: 358 ns per loop
```
Timings after:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
10000000 loops, best of 5: 175 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 220 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 316 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
10000000 loops, best of 5: 54.9 ns per loop
```
That is, when a value is accessed directly via the holder object, this change is a significant speedup. When it is accessed via `jax.config`, this change is a good speedup for `StateContextManager` values and a small slowdown for `FlagHolder` values.
In a subsequent change, we should do more work to avoid using `jax.config.xyz` to access flag values inside jax.
I also note that one of the reasons `StateContextManager` values are a bit suboptimal is that they still require a dictionary lookup in a `threading.local` object. We can probably avoid that with a small C++ extension.
PiperOrigin-RevId: 606340656
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array; in pmap terminology, it uses a "chunked" axis instead of an "unstacked" axis.
For example, a typical array used by pmap might have a shape of, say, [8, 100] if sharded across 8 accelerators on its first axis. Previously, each individual shard would have a shape of [100]. With this change, each individual shard has a shape of [1, 100] instead.
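The two conventions can be illustrated with plain shape arithmetic. This is only a sketch; the helper names `unstacked_shard_shape` and `chunked_shard_shape` are hypothetical and do not exist in JAX:

```python
def unstacked_shard_shape(shape, axis=0):
    """Shape of one shard when the mapped axis is removed (rank drops by one)."""
    return tuple(d for i, d in enumerate(shape) if i != axis)


def chunked_shard_shape(shape, num_shards, axis=0):
    """Shape of one shard when the mapped axis is kept and split evenly."""
    if shape[axis] % num_shards != 0:
        raise ValueError("axis size must be divisible by the number of shards")
    return tuple(d // num_shards if i == axis else d for i, d in enumerate(shape))


global_shape = (8, 100)
old_shard = unstacked_shard_shape(global_shape)   # (100,)
new_shard = chunked_shard_shape(global_shape, 8)  # (1, 100)
```

Only the chunked form keeps the rank of the global array, which is the property the sharding representation discussed in this message relies on.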
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user visible impacts from this change.
PiperOrigin-RevId: 599787818
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.
Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.
Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is greater than the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.
Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.
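A minimal sketch of the combined check, assuming an entry is written only when compilation took at least the time threshold and the serialized executable is at least the size threshold. The function and parameter names are hypothetical:

```python
def should_cache(compile_time_secs, serialized_size_bytes,
                 min_compile_time_secs, min_entry_size_bytes):
    """Return True only if the entry passes both thresholds."""
    took_long_enough = compile_time_secs >= min_compile_time_secs
    large_enough = serialized_size_bytes >= min_entry_size_bytes
    return took_long_enough and large_enough
```

Under these assumptions, a slow compilation that yields a tiny executable is no longer written to the cache.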
Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information.
The main idea here is that pointing out which parts of the cache key differ
from previously-seen keys can constitute a pretty good explanation.
This PR adds an explanation mechanism. It can be enabled in a few different ways:
* setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
* setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
* using the `jax._src.config.explain_cache_misses` context
manager (not in public namespace yet);
* when parsing command line flags with absl, using the
`--jax_explain_cache_misses` flag.
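The idea of explaining a miss by comparing cache keys can be sketched as below. This is not the mechanism added by this PR; `explain_cache_miss` and the dictionary-shaped keys are assumptions made purely for illustration:

```python
def explain_cache_miss(new_key, seen_keys):
    """Describe how new_key differs from the closest previously seen key."""
    if not seen_keys:
        return ["no previous cache entries"]

    def differing_fields(old_key):
        fields = sorted(set(old_key) | set(new_key))
        return [f for f in fields if old_key.get(f) != new_key.get(f)]

    # Pick the previously seen key with the fewest differing fields.
    closest = min(seen_keys, key=lambda k: len(differing_fields(k)))
    return [
        f"{field}: saw {closest.get(field)!r}, now {new_key.get(field)!r}"
        for field in differing_fields(closest)
    ]


seen = [{"fun": "f", "dtype": "float32", "shape": (8,)}]
miss = {"fun": "f", "dtype": "float64", "shape": (8,)}
explanation = explain_cache_miss(miss, seen)
# explanation == ["dtype: saw 'float32', now 'float64'"]
```

Reporting only the fields that differ from the nearest earlier key is what turns "a cache miss happened" into something actionable, such as a changing dtype.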
Co-authored-by: Yash Katariya <yashkatariya@google.com>
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.
Note: since we still fall back to hashing devices +
platform if the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.
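The fallback mentioned in the note might look roughly like the following. `TopologySketch`, `hash_accelerator_config`, and `key_for` are hypothetical stand-ins, and the real key contains more than is shown here:

```python
import hashlib


class TopologySketch:
    """Hypothetical topology whose serialization may be unimplemented."""

    def __init__(self, platform, device_kinds, serialized=None):
        self.platform = platform
        self.device_kinds = device_kinds
        self._serialized = serialized

    def serialize(self):
        if self._serialized is None:
            raise NotImplementedError("backend does not serialize topology")
        return self._serialized


def hash_accelerator_config(hash_obj, topology):
    """Hash the serialized topology, or devices + platform as a fallback."""
    try:
        hash_obj.update(topology.serialize())
    except NotImplementedError:
        # Fallback path: hash the devices and the platform instead.
        for kind in topology.device_kinds:
            hash_obj.update(kind.encode("utf-8"))
        hash_obj.update(topology.platform.encode("utf-8"))


def key_for(topology):
    """Return a hex digest covering only the accelerator configuration."""
    h = hashlib.sha256()
    hash_accelerator_config(h, topology)
    return h.hexdigest()
```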
Testing: test workload.
PiperOrigin-RevId: 590378036
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.
Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.
Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
The new cache-key generation algorithm is more robust and
results in fewer stale entries being returned.
Testing: test workloads.
PiperOrigin-RevId: 579928158
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.
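One way to picture the invalidation, as a sketch only: a setter that clears a memoized compile step whenever the version changes. `compile_sketch`, `set_profile_version`, and `_state` are hypothetical and stand in for JAX's real tracing and compilation caches:

```python
import functools

# Hypothetical holder for the current profile version.
_state = {"profile_version": 0}


@functools.lru_cache(maxsize=None)
def compile_sketch(program):
    """Stand-in for an expensive compile step that reads the profile version."""
    return f"{program}@profile{_state['profile_version']}"


def set_profile_version(new_version):
    """Update the version and drop cached results built with the old one."""
    if new_version != _state["profile_version"]:
        _state["profile_version"] = new_version
        compile_sketch.cache_clear()


first = compile_sketch("f")   # compiled with version 0
set_profile_version(1)
second = compile_sketch("f")  # recompiled with version 1
```

Without the `cache_clear()` call, the second lookup would return the stale result built against the old profile.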
Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.