Before:
```
In [2]: %timeit jax._src.config.config._trace_context()
The slowest run took 23.63 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 3.5 µs per loop
```
After:
```
In [5]: %timeit jax._src.config.trace_context()
The slowest run took 12.16 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 2.59 µs per loop
```
It's slightly faster to access flags directly via the holder object, rather than via `jax.config`.
PiperOrigin-RevId: 606366377
Previously accessing the value of a configuration option required accessing a `values` dictionary in jax._src.config.config. By moving the values into the FlagHolder and StateContextManager objects, we allow for accessing the values directly without as many dictionary accesses.
Timings before, note `enable_x64` is a `StateContextManager` value and `jax_pprint_use_color` is a `FlagHolder` value:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
1000000 loops, best of 5: 328 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 377 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 293 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
1000000 loops, best of 5: 358 ns per loop
```
Timings after:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
10000000 loops, best of 5: 175 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 220 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 316 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
10000000 loops, best of 5: 54.9 ns per loop
```
i.e., if accessed via the holder object directly, this change is a significant speedup, and if accessed by the `jax.config` dictionary this change is a good speedup for `StateContextManager` values and a small slowdown for `FlagHolder` values.
In a subsequent change, we should do more work to avoid using `jax.config.xyz` to access flag values inside jax.
I also note that one of the reasons `StateContextManager` values are a bit suboptimal is that they still require a dictionary lookup in a `threading.local` object. We can probably avoid that with a small C++ extension.
PiperOrigin-RevId: 606340656
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.
i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user visible impacts from this change.
PiperOrigin-RevId: 599787818
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.
Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.
Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is less than the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.
Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.
Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information
The main idea here is that pointing out which parts of the cache key differs
from previously-seen keys can constitute a pretty good explanation.
This PR adds an explanation mechanism. It can be enabled in a few different ways:
* setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
* setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
* using the context manager `jax._src.config.explain_cache_misses` context
manager (not in public namespace yet);
* when parsing command line flags with absl, using the
`--jax_explain_cache_misses` flag.
Co-authored-by: Yash Katariya <yashkatariya@google.com>
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.
Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.
Testing: test workload.
PiperOrigin-RevId: 590378036
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.
Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.
Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
The new cache-key generation algorithm is more robust and
results in fewer stale entries being returned.
Testing: test workloads.
PiperOrigin-RevId: 579928158
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.
Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
PiperOrigin-RevId: 571932143
This change is a follow-on to google/jax#16866, which added an ABSL-like API
for flags defined with `DEFINE_...`. Here we add a similar typed API for flags
defined with `define_..._state`.
See 37dad4d356/absl/flags/_flagvalues.py (L1333).
PiperOrigin-RevId: 570721827
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.
If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.
PiperOrigin-RevId: 564457393
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.
See the CHANGELOG for details.
PiperOrigin-RevId: 556222908
The new cache-key generation algorithm will coexist with the original
version until the new one is fully deployed. While they coexist,
--jax_use_original_compilation_cache_key_generation will determine which
one is used. Once the new algorithm is deployed, the original algorithm
and this flag will be removed.
This change sets up the plumbing. Later changes will implement the new
algorithm.
Testing: test workload.
PiperOrigin-RevId: 555333628
Changing the flag to a config permits more contained testing.
This is in preparation for an upcoming change to incorporate
AutoFDO profile versions in the cache key.
Testing: test workload.
PiperOrigin-RevId: 554942573
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.
PiperOrigin-RevId: 554608165
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.
For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.
Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.
This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
PiperOrigin-RevId: 551604974
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().