188 Commits

Author SHA1 Message Date
Peter Hawkins
995bdaa912 Small optimization to trace_context.
Before:
```
In [2]: %timeit jax._src.config.config._trace_context()
The slowest run took 23.63 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 3.5 µs per loop
```

After:
```
In [5]: %timeit jax._src.config.trace_context()
The slowest run took 12.16 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 5: 2.59 µs per loop
```

It's slightly faster to access flags directly via the holder object, rather than via `jax.config`.

PiperOrigin-RevId: 606366377
2024-02-12 14:30:11 -08:00
Peter Hawkins
c1f234a95c Refactor and optimize the implementation of config options.
Previously, accessing the value of a configuration option required accessing a `values` dictionary in `jax._src.config.config`. By moving the values into the FlagHolder and StateContextManager objects, we allow for accessing the values directly without as many dictionary accesses.

Timings before, note `enable_x64` is a `StateContextManager` value and `jax_pprint_use_color` is a `FlagHolder` value:

```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
1000000 loops, best of 5: 328 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 377 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 293 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
1000000 loops, best of 5: 358 ns per loop
```

Timings after:
```
In [1]: import jax
In [2]: %timeit jax._src.config.enable_x64.value
10000000 loops, best of 5: 175 ns per loop
In [3]: %timeit jax.config.jax_enable_x64
1000000 loops, best of 5: 220 ns per loop
In [4]: %timeit jax.config.read("jax_pprint_use_color")
1000000 loops, best of 5: 316 ns per loop
In [5]: %timeit jax._src.pretty_printer._PPRINT_USE_COLOR.value
10000000 loops, best of 5: 54.9 ns per loop
```

i.e., if accessed via the holder object directly, this change is a significant speedup, and if accessed by the `jax.config` dictionary this change is a good speedup for `StateContextManager` values and a small slowdown for `FlagHolder` values.

In a subsequent change, we should do more work to avoid using `jax.config.xyz` to access flag values inside jax.

I also note that one of the reasons `StateContextManager` values are a bit suboptimal is that they still require a dictionary lookup in a `threading.local` object. We can probably avoid that with a small C++ extension.
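
The per-thread lookup described above can be pictured with a minimal sketch. This is not JAX's implementation: `StateHolder`, `override`, and `_UNSET` are hypothetical names, and the sketch only assumes that a state value has a global default that a thread can temporarily override.

```python
import threading
from contextlib import contextmanager

_UNSET = object()  # sentinel meaning "no thread-local override is active"


class StateHolder:
    """Hypothetical holder: a global default plus an optional per-thread override."""

    def __init__(self, default):
        self._global_value = default
        self._local = threading.local()

    @property
    def value(self):
        # Attribute access on threading.local is itself a per-thread dict lookup.
        local_value = getattr(self._local, "value", _UNSET)
        return self._global_value if local_value is _UNSET else local_value

    @contextmanager
    def override(self, new_value):
        previous = getattr(self._local, "value", _UNSET)
        self._local.value = new_value
        try:
            yield
        finally:
            # Restore whatever this thread saw before entering the block.
            if previous is _UNSET:
                del self._local.value
            else:
                self._local.value = previous


enable_x64 = StateHolder(False)  # illustrative name only
with enable_x64.override(True):
    inside = enable_x64.value
outside = enable_x64.value
```

Every read of `value` pays for the thread-local lookup even when no override is active, which is the cost the note above suggests a C++ extension could avoid.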

PiperOrigin-RevId: 606340656
2024-02-12 13:04:38 -08:00
Peter Hawkins
6b8c99c60e Tighten up typing of optional config values.
Distinguish optional config values from non-optional config values, and accurately type optional config values.

PiperOrigin-RevId: 606235519
2024-02-12 06:44:02 -08:00
jax authors
9a098e922a Share autotune config between hosts.
PiperOrigin-RevId: 604569298
2024-02-06 01:28:18 -08:00
George Necula
fdf227e7b2 [export] Set default native serialization version to 9.
This version adds better support for JAX effects.

See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.

PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e., in the terminology of pmap, using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.
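
The shape arithmetic in this example can be checked with a small pure-Python sketch. The helper `shard_shape` and its `keep_rank` parameter are hypothetical and are not part of pmap; they only model the difference between an "unstacked" and a "chunked" axis.

```python
def shard_shape(global_shape, axis, num_shards, keep_rank):
    """Shape of one shard when `axis` is split evenly across `num_shards`."""
    if global_shape[axis] % num_shards != 0:
        raise ValueError("axis size must be divisible by the number of shards")
    chunk = global_shape[axis] // num_shards
    shape = list(global_shape)
    if keep_rank:
        # "chunked": the axis stays, with a smaller size, so the rank is unchanged
        shape[axis] = chunk
    else:
        # "unstacked": the axis is removed, so the rank drops by one
        if chunk != 1:
            raise ValueError("an axis can only be dropped when each shard has size 1")
        del shape[axis]
    return tuple(shape)


unstacked = shard_shape((8, 100), axis=0, num_shards=8, keep_rank=False)
chunked = shard_shape((8, 100), axis=0, num_shards=8, keep_rank=True)
```

Here `unstacked` is `(100,)` and `chunked` is `(1, 100)`, matching the two shard shapes described above.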

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX; for example, in a subsequent change we can probably delete PmapSharding entirely. This in turn would also allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
jax authors
ab8eb896d7 Document several jax.config methods.
PiperOrigin-RevId: 598775913
2024-01-16 02:26:41 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
jax authors
adbbe69cc2 Add option to share compiled module between hosts.
PiperOrigin-RevId: 597754861
2024-01-11 23:38:02 -08:00
jax authors
ea66029731 Introduce min entry size check for compilation cache.
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is at least the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.

Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.
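
The combined check might look like the sketch below. The function and parameter names are hypothetical, and the sketch assumes that an entry passes the time check when compilation took at least the time threshold.

```python
def should_cache(compile_time_secs, entry_size_bytes,
                 min_compile_time_secs, min_entry_size_bytes):
    """Return True only if the entry passes both the time and the size check."""
    if compile_time_secs < min_compile_time_secs:
        return False  # compiled quickly: cheaper to recompile than to cache
    if entry_size_bytes < min_entry_size_bytes:
        return False  # tiny serialized executable: avoid many small cache files
    return True


slow_but_tiny = should_cache(5.0, 100, min_compile_time_secs=1.0,
                             min_entry_size_bytes=1024)
slow_and_large = should_cache(5.0, 4096, min_compile_time_secs=1.0,
                              min_entry_size_bytes=1024)
```

With these illustrative thresholds, the slow-but-tiny entry is rejected by the size check, while the slow-and-large entry is cached.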

Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
2024-01-04 15:17:05 -08:00
jax authors
981b670453 [Jax] Allow to set the python traceback frames limit.
PiperOrigin-RevId: 595607107
2024-01-03 23:26:55 -08:00
Matthew Johnson
9112dcebc9 add jax.explain_cache_misses tracing cache miss explanations
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information.

The main idea here is that pointing out which parts of the cache key differ
from previously-seen keys can constitute a pretty good explanation.
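
A minimal sketch of that idea, assuming cache keys can be modelled as dictionaries with the same field names; `explain_miss` and the key fields are hypothetical and do not reflect how JAX builds or compares its keys.

```python
def explain_miss(new_key, seen_keys):
    """Describe how `new_key` differs from the closest previously seen key."""
    if not seen_keys:
        return ["no previously seen keys (first call)"]

    def differing_fields(old_key):
        # Only fields present in the new key are compared in this sketch.
        return [name for name in new_key if old_key.get(name) != new_key[name]]

    # The closest key is the one that disagrees on the fewest fields.
    closest = min(seen_keys, key=lambda old_key: len(differing_fields(old_key)))
    return [f"{name}: was {closest.get(name)!r}, now {new_key[name]!r}"
            for name in differing_fields(closest)]


seen = [{"fun": "f", "dtype": "float32", "shape": (3,)}]
report = explain_miss({"fun": "f", "dtype": "int32", "shape": (3,)}, seen)
```

For this input, `report` contains the single line `dtype: was 'float32', now 'int32'`, which points directly at the changing dtype.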

This PR adds an explanation mechanism. It can be enabled in a few different ways:
  * setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
  * setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
  * using the context manager `jax._src.config.explain_cache_misses`
    (not in public namespace yet);
  * when parsing command line flags with absl, using the
    `--jax_explain_cache_misses` flag.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-12-26 21:54:27 -08:00
jax authors
259c285b10 [Jax] Enable jax_include_full_tracebacks_in_locations by default
PiperOrigin-RevId: 591783126
2023-12-17 21:56:13 -08:00
jax authors
196c97fa0c Merge pull request #18949 from froystig:seed-offset
PiperOrigin-RevId: 590637382
2023-12-13 10:18:40 -08:00
Roy Frostig
671790730e introduce a config flag to control a random seed offset
2023-12-12 18:31:07 -08:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform if the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking
2023-12-11 12:03:48 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
jax authors
b9b5410ddd Default-enable the Jax persistent compilation cache.
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.

Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.

Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
2023-11-27 14:53:20 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Yash Katariya
cf3c041366 Disable jax memories flag.
PiperOrigin-RevId: 580961421
2023-11-09 10:54:02 -08:00
jax authors
28e33ca5d1 Switch to the new cache-key generation algorithm.
The new cache-key generation algorithm is more robust and
results in fewer stale entries being returned.

Testing: test workloads.
PiperOrigin-RevId: 579928158
2023-11-06 12:57:01 -08:00
Roy Frostig
b22e75716f add threefry_partitionable config setting to thread-local JIT context
2023-10-31 13:45:49 -07:00
Sergei Lebedev
fd3a8b2cc6 Deprecated define_* and DEFINE_* methods on jax.config
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
2023-10-29 20:58:19 +00:00
jax authors
9ba305cced Invalidate in-memory caches on XLA-AutoFDO profile version change.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.
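
One way to picture this invalidation is the sketch below. `VersionedCache` is a hypothetical stand-in, not one of JAX's tracing or compilation caches; it only assumes that a change in a tracked version number makes every cached result stale.

```python
class VersionedCache:
    """Hypothetical in-memory cache that is dropped when a tracked version changes."""

    def __init__(self):
        self._version = None
        self._entries = {}

    def get_or_compute(self, key, version, compute):
        if version != self._version:
            # The tracked setting changed, so every cached result is stale.
            self._entries.clear()
            self._version = version
        if key not in self._entries:
            self._entries[key] = compute()
        return self._entries[key]


calls = []


def compile_program():
    calls.append("compiled")
    return len(calls)


cache = VersionedCache()
first = cache.get_or_compute("prog", version=1, compute=compile_program)
second = cache.get_or_compute("prog", version=1, compute=compile_program)  # cache hit
third = cache.get_or_compute("prog", version=2, compute=compile_program)  # recomputed
```

The second call reuses the cached result, while the third call recompiles because the version changed.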

Testing:
* New unit test.
* Test workload with instrumentation to repeatedly change
  the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
2023-10-27 12:52:57 -07:00
Sergei Lebedev
f2ce5dbd01 MAINT Do not use str() and repr() in f-string replacement fields
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
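
A short illustration of the equivalence, using a hypothetical `Flag` class that defines only `__str__` and `__repr__`:

```python
class Flag:
    """Hypothetical class used only to show the two conversions."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"Flag({self.name!r})"


flag = Flag("jax_enable_x64")
plain = f"{flag}"     # same text as f"{str(flag)}"
quoted = f"{flag!r}"  # same text as f"{repr(flag)}"
```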
2023-10-23 15:12:04 +01:00
George Necula
e89212c81a [export] Set the default export serialization version to 8.
This version has been supported by XlaCallModule since July 21, 2023 and we are now past the forward-compatibility window.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

Reverts ae81ac9cc21696a22b973b1eae6ce222c7318ba7

PiperOrigin-RevId: 575382324
2023-10-20 21:00:55 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Sergei Lebedev
923498fb45 _StateContextManager now preserves the type of the value it stores.
This change is a follow-on to google/jax#16866, which added an ABSL-like API
for flags defined with `DEFINE_...`. Here we add a similar typed API for flags
defined with `define_..._state`.

See 37dad4d356/absl/flags/_flagvalues.py (L1333).

PiperOrigin-RevId: 570721827
2023-10-04 09:49:19 -07:00
George Necula
ae81ac9cc2 Reverts 1a9c94e6265b40c8f1de1e6d920208d648d70fdd
PiperOrigin-RevId: 568838127
2023-09-27 06:58:35 -07:00
George Necula
1a9c94e626 [export] Set the default export serialization version to 8.
This version has been supported by XlaCallModule since July 21, 2023 and we are now past the forward-compatibility window.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

PiperOrigin-RevId: 568777006
2023-09-27 01:28:03 -07:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package
2023-09-22 14:54:31 -07:00
Ruoxin Sang
3e06dc8b77 Update jax_spmd_mode flag docstring and remove unused allow_pjit option.
PiperOrigin-RevId: 564543943
2023-09-11 17:08:35 -07:00
Yash Katariya
a36598b2a7 Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.

If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.

PiperOrigin-RevId: 564457393
2023-09-11 11:55:09 -07:00
Yash Katariya
ccb88140ec Make apply_primitive go via C++ fast dispatch.
This leads to a ~30% faster dispatch time. Ideally, we should replace this with jit, but that has its own set of problems that I will look into later.

```
eager_unary_dispatch                                  40.3µs ± 2%             29.2µs ± 9%  -27.51%          (p=0.008 n=5+5)
eager_unary                                           40.6µs ± 0%             31.1µs ±11%  -23.41%          (p=0.016 n=4+5)
eager_binary_dispatch                                 49.6µs ± 0%             34.5µs ± 8%  -30.58%          (p=0.016 n=4+5)
eager_binary                                          50.2µs ± 1%             35.4µs ± 9%  -29.38%          (p=0.016 n=4+5)
bench_remat_eager_retracing_overheads                 13.0ms ± 1%             11.3ms ± 8%  -13.26%          (p=0.008 n=5+5)
bench_remat_eager_retracing_overheads_static_argnums  13.3ms ± 0%             12.3ms ± 6%   -7.34%          (p=0.016 n=4+5)
bench_repeated_static_indexing                         112ms ± 2%               82ms ± 5%  -26.46%          (p=0.008 n=5+5)
bench_repeated_static_slicing                         90.5ms ± 1%             68.3ms ± 5%  -24.54%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 561774696
2023-08-31 15:25:11 -07:00
Yash Katariya
eea2603363 Add a proper jax config for memories so that we can iteratively develop and enable it.
PiperOrigin-RevId: 559977015
2023-08-24 22:23:55 -07:00
Jake VanderPlas
630a69f41b [random] add jax_legacy_prng_key flag
2023-08-22 15:08:51 -07:00
George Necula
cf4e1d414b [jax2tf] Bump the default JAX serialization version to 7.
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
2023-08-11 22:49:31 -07:00
jax authors
aac4cdad56 Set up plumbing for adding new compilation-cache-key generation algorithm.
The new cache-key generation algorithm will coexist with the original
version until the new one is fully deployed. While they coexist,
--jax_use_original_compilation_cache_key_generation will determine which
one is used. Once the new algorithm is deployed, the original algorithm
and this flag will be removed.

This change sets up the plumbing. Later changes will implement the new
algorithm.
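
The coexistence pattern can be sketched as a flag that selects between two key functions. Both key functions below are hypothetical stand-ins and do not reflect either of JAX's actual algorithms; only the flag-based dispatch is the point.

```python
import hashlib


def _original_key(program, options):
    # Stand-in for the legacy algorithm: hashes the options in insertion order.
    return hashlib.sha256(repr((program, options)).encode()).hexdigest()


def _new_key(program, options):
    # Stand-in for the replacement: hashes the options in sorted order.
    digest = hashlib.sha256(program.encode())
    for name in sorted(options):
        digest.update(f"{name}={options[name]!r}".encode())
    return digest.hexdigest()


def cache_key(program, options, use_original):
    """Pick the key algorithm based on a boolean flag during the migration."""
    impl = _original_key if use_original else _new_key
    return impl(program, options)


opts_a = {"opt_level": 2, "platform": "cpu"}
opts_b = {"platform": "cpu", "opt_level": 2}
same_with_new = cache_key("module", opts_a, False) == cache_key("module", opts_b, False)
same_with_old = cache_key("module", opts_a, True) == cache_key("module", opts_b, True)
```

Keeping both implementations behind one entry point means callers do not change when the flag, and later the original algorithm, is removed.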

Testing: test workload.
PiperOrigin-RevId: 555333628
2023-08-09 18:16:22 -07:00
jax authors
d01695c746 Change --jax_xla_profile_version definition to config.
Changing the flag to a config permits more contained testing.
This is in preparation for an upcoming change to incorporate
AutoFDO profile versions in the cache key.

Testing: test workload.
PiperOrigin-RevId: 554942573
2023-08-08 14:29:09 -07:00
jax authors
5efc681702 Cleanup comments for define_{int,float}_state.
There is no parameter enum_values in these functions.
Probably a copy/paste issue from define_enum_state.

PiperOrigin-RevId: 554644871
2023-08-07 17:41:09 -07:00
Peter Hawkins
c879f65aa6 [JAX] Remove the non-coordination service distributed service implementation from JAX.
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.

PiperOrigin-RevId: 554608165
2023-08-07 15:17:25 -07:00
George Necula
8d80e2587b [jax2tf] Turn on JAX native serialization by default.
See changes to the README.md for mechanisms to override the default.

PiperOrigin-RevId: 554390866
2023-08-07 01:03:55 -07:00
Patrick Kidger
5e276d0935 Tracebacks no longer have JAX-internal frames prepended by default
2023-08-03 11:38:38 -07:00
Peter Hawkins
a40f900e23 Fix exception if additional ABSL flags were registered after config_with_absl() was called.
PiperOrigin-RevId: 552584357
2023-07-31 13:59:01 -07:00
Skye Wanderman-Milne
8b58e38ec5 Add jax_debug_log_modules config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).

Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
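
The effect of such an option can be approximated with the standard `logging` module alone. The helper below is a hypothetical sketch, not the code behind `jax_debug_log_modules`; it assumes the option value is a comma-separated list of logger names.

```python
import logging


def enable_debug_logging(modules):
    """Set DEBUG level on each logger named in a comma-separated string."""
    names = [name.strip() for name in modules.split(",") if name.strip()]
    for name in names:
        logging.getLogger(name).setLevel(logging.DEBUG)
    return names


enabled = enable_debug_logging("jax._src.xla_bridge,jax._src.dispatch")
```

Because Python loggers form a dotted hierarchy, setting the level on a parent name such as `jax` also applies to child loggers that have no level of their own.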
2023-07-28 18:11:12 +00:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
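
The typing benefit can be shown with a small generic holder. `TypedFlag` and `define_bool` are hypothetical names chosen for this sketch and are neither the absl nor the JAX classes.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class TypedFlag(Generic[T]):
    """Hypothetical typed holder: `.value` has a type the checker can see."""

    def __init__(self, name: str, default: T) -> None:
        self.name = name
        self.value: T = default


def define_bool(name: str, default: bool) -> TypedFlag[bool]:
    return TypedFlag(name, default)


# Untyped style: a checker only knows that the values are `object`.
registry: dict[str, object] = {"pprint_use_color": True}

# Typed style: a checker knows that `.value` is a `bool`.
PPRINT_USE_COLOR = define_bool("pprint_use_color", True)
use_color = PPRINT_USE_COLOR.value
```

Both styles return the same value at runtime; the difference is only in what a type checker or IDE can infer.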

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
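
As an illustration of the first and third items, the hypothetical function below uses PEP 585 built-in generics in place of `typing.Dict` and `typing.List`, and an f-string in place of `%` formatting. It is not taken from the JAX sources.

```python
def summarize(counts: dict[str, list[int]]) -> list[str]:
    # PEP 585: built-in generics replace typing.Dict and typing.List (Python 3.9+).
    return [f"{name}: {sum(values)}" for name, values in sorted(counts.items())]


lines = summarize({"b": [1, 2], "a": [3]})
```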
2023-07-21 14:49:44 -04:00