39 Commits

Author SHA1 Message Date
Parker Schuh
eba0564b70 Allow disabling compilation cache for particular runtime_types.
PiperOrigin-RevId: 640264856
2024-06-04 13:30:42 -07:00
jax authors
26f9820417 [JAX] Automatically share PGO data for GPU latency-hiding scheduler.
Overall the idea is to collect profile data for each module given amount of times (which can be configured) then recompile the module with the aggregated profile data.

1. We need to track how many times each module were profiled and collect profiling results. For this i added a ProfileSessionRunner class at profile.py. The class can track how many times an instance of it was called to profile a session and also can aggregate profile results.

2. We need associate profiling session to the module at the interpreter. To do this i added a dictionary to pjit.py which associates Jaxpr with profile session runner.

3. The profile session runner should be passed to pxla.py and then called.

4. We need to correctly deal with fast path at the interpreter level, so JAX won't use HLO directly if PGLE need to be collected, but also JAX will not recompiled the module only for PGLE. See changes in pjit.py and in lru_cache.h

5. Once FDO is collected we need to share it between hosts to keep deterministic compilation.

PiperOrigin-RevId: 638197166
2024-05-29 01:50:03 -07:00
jax authors
e0bf783c3e [JAX] Fix multi-process data sharing comments.
PiperOrigin-RevId: 636817251
2024-05-24 00:22:28 -07:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
jax authors
95c4ba961c [JAX] Use first process id instead of process 0 to share multi-host data.
PiperOrigin-RevId: 633526220
2024-05-14 03:45:41 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
jax authors
7681493760 Don't create temp directory when module is getting imported.
PiperOrigin-RevId: 630958402
2024-05-06 00:58:45 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
jax authors
98ad6ef057 Add debug logging for autotune profile sharing.
PiperOrigin-RevId: 614915339
2024-03-11 22:46:56 -07:00
jax authors
3708336f8f To avoid the inconsistency between process_index and process_id, replace backend.process_index with distributed.global_state.process_id in Jax compilation _cache_write function.
Testing: new unit test.
PiperOrigin-RevId: 607385112
2024-02-15 10:48:05 -08:00
Junwhan Ahn
b9c7296754 Add an interface for IFRT backends to specify executable serializability and use this information to automatically disable JAX compilation cache for backends that do not implement executable serialization
This CL adds `supports_executable_serialization` as an IFRT client attribute indicating whether the executables produced by a given IFRT implementation are serializable or not. This is based on a principle where `xla::ifrt::Client::attributes()` returns a set of attributes representing the "capabilities" of an IFRT implementation, so that the users of IFRT can act based on such capabilities without having to know the exact backend that they are using.

This change is backward compatible as IFRT backends that do not implement `supports_executable_serialization` are assumed to implement executable serialization.

PiperOrigin-RevId: 606799188
2024-02-13 17:27:58 -08:00
jax authors
9a098e922a Share autotune config between hosts.
PiperOrigin-RevId: 604569298
2024-02-06 01:28:18 -08:00
Eugene Zhulenev
28ef77dfb0 Disable JAX compilation cache for XLA:CPU
PiperOrigin-RevId: 603551262
2024-02-01 19:20:39 -08:00
jax authors
34d22fc498 Exclude modules with host callbacks from inter-host sharing.
PiperOrigin-RevId: 600697796
2024-01-23 00:38:13 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
jax authors
adbbe69cc2 Add option to share compiled module between hosts.
PiperOrigin-RevId: 597754861
2024-01-11 23:38:02 -08:00
Eugene Zhulenev
ba4c2b1c7d [pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
jax authors
c28aa2cecc Remove CPU support from compilation cache.
CPU support was originally added to the compilation cache
in anticipation of the availability of CPU acceleration
compilation. Since this is not available and the
--xla_cpu_use_xla_runtime flag has been deprecated,
cleanup the code and test.

Testing: test workload, revised unit test.
PiperOrigin-RevId: 592962316
2023-12-21 15:25:56 -08:00
George Necula
bb84e6c22e Improve support for JAX_DUMP_IR_TO.
Previously the environment variable JAX_DUMP_IR_TO controlled
whether and where to dump the MLIR module prior to compilation. Now we move the code for that support from
compiler.py to mlir.py, so that it can be used in other
parts of the code. We also add support for logging to Sponge.

Using this support we now log the module on errors from
refine_polymorphic_shapes.

PiperOrigin-RevId: 592099633
2023-12-18 21:25:45 -08:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
jax authors
d3f4bbfdd0 Fix cache_used metric implementation and test.
The cache_used metric is incremented once per task and is
used to determine how many tasks are using the Jax
compilation cache. The current implementation and unit
test are not thread safe. This results in the test
failing when unit tests are executed concurrently.

The fix is to make the implementation thread safe and
to update the test to examine the delta in the metric.

Testing: Cloud TPU VM testing; test workload.
PiperOrigin-RevId: 589174850
2023-12-08 10:30:22 -08:00
Adam Paszke
d80c15aaee Downgrade a bunch of logging to DEBUG
The logs related to compilation cache ended up being quite chatty,
which is quite unlike the other logs in JAX. This downgrades a bunch
of them to debug, as they can always be enabled independently
using JAX config. This should also fix the recent failures in
logging_test.py.
2023-11-29 12:10:53 +00:00
jax authors
fc8058a17d Restrict retrieving XLA-AutoFDO profile version to TPU workloads.
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.

Testing: test workload.
PiperOrigin-RevId: 584148686
2023-11-20 15:52:03 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
jax authors
7e372944f9 Fix the missing cache_misses metric when min compile time is set to zero.
Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.

Testing: revised unit test.
PiperOrigin-RevId: 579951415
2023-11-06 14:04:35 -08:00
Yash Katariya
8b05b1623c Make the directories (and it's parents) specified in JAX_DUMP_IR_TO flag if they don't exist
PiperOrigin-RevId: 576151618
2023-10-24 08:39:51 -07:00
jax authors
65cfe1a5a3 Instrument metrics for the new JAX compilation cache key generation algorithm.
Metrics:
1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation.
2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation.
PiperOrigin-RevId: 573019115
2023-10-12 14:56:02 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
4611d13c07 Only perform compilation cache writes from process 0.
This avoids problems with contending writes on filesystems such as GCS.

PiperOrigin-RevId: 572032482
2023-10-09 13:55:07 -07:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Peter Hawkins
b2ac2deb5b [XLA] Split --xla_detailed_logging_and_dumping debug flag into --xla_detailed_logging and --xla_enable_dumping.
We want to suppress detailed logging (notably on TPU, which has pretty verbose detailed logging) separately from disabling HLO dumps. Even if we don't print detailed log information, it's quite surprising if an HLO module doesn't show up in the set of modules dumped by XLA.

PiperOrigin-RevId: 570374492
2023-10-03 06:59:47 -07:00
jax authors
6869000636 Modify the name for JAX compilation cache adoption metric.
Rename the adoption metric from '/jax/compilation_cache/tasks_using_original_cache' to '/jax/compilation_cache/tasks_using_cache' to record the number of tasks using compilation cache for both original and new implementations.

We realized that one adoption metric is enough to monitor and compare the original and new cache adoption rates. To avoid confusion of the word 'original' in the metric name, we decide to change the metric name to better describe the purpose of this adoption metric.

Testing: revised unit test.
PiperOrigin-RevId: 565197027
2023-09-13 16:44:10 -07:00
Peter Hawkins
729752b32b Disable XLA detailed logging and dumping for small computations.
This significantly reduces the amount of logging from XLA on TPU.

PiperOrigin-RevId: 565148809
2023-09-13 13:45:00 -07:00
jax authors
23e4f0b471 Hash serialized topology description for new cache key generation.
The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription which is supported for all accelerators.

Note:
. CPU and PjRt C API not supported as yet.
. Stream Executor will not be supported.

Testing: revised unit test.
PiperOrigin-RevId: 564461564
2023-09-11 12:08:26 -07:00
jax authors
80f6151110 Instrument metrics to track cache hit rate of original JAX compilation cache.
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of  the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.

Created a context manager to register/unregister event listeners.

PiperOrigin-RevId: 561771262
2023-08-31 15:05:23 -07:00
Skye Wanderman-Milne
c523e3b1d0 Turn down jax_persistent_cache_min_compile_time_secs logging from info to debug.
It's very noisy otherwise, since jax usually produces many small computations that aren't cached.
2023-08-28 19:49:25 +00:00
jax authors
27d19ee233 Instrument a metric to measure the number of tasks using compilation cache in JAX -> PJRT.
Create the metric '/jax/compilation_cache/tasks_using_original_cache' to record the number of tasks using compilation cache.

PiperOrigin-RevId: 559159282
2023-08-22 10:48:58 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00