90 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
a59bbb7cd7 Add test utility for accessing jaxlib version tuple.
We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests.
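
A minimal sketch of conditioning a test on a version tuple, as described above. `JAXLIB_VERSION` and `requires_jaxlib` are hypothetical stand-ins chosen for illustration; they are not the attribute this change adds to `jtu`.

```python
import unittest

# Hypothetical stand-in for the version tuple exposed by the test utilities.
JAXLIB_VERSION = (0, 4, 36)


def requires_jaxlib(minimum):
    """Skip a test unless the (hypothetical) version tuple is at least `minimum`."""
    return unittest.skipIf(
        JAXLIB_VERSION < minimum,
        f"requires jaxlib >= {'.'.join(map(str, minimum))}",
    )


class ExampleTest(unittest.TestCase):
    @requires_jaxlib((0, 4, 35))
    def test_runs(self):
        self.assertTrue(True)

    @requires_jaxlib((0, 5, 0))
    def test_skipped(self):
        self.fail("should not run on an older jaxlib")
```

Tuples compare element by element, so `(0, 4, 36) < (0, 5, 0)` holds without any string parsing in the test itself.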

PiperOrigin-RevId: 698097487
2024-11-19 12:00:32 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
The only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.

Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Peter Hawkins
fc4f554e09 Delete jax.lib.xla_client.execute_with_python_values.
Nothing under jax.lib.xla_client is public, so there's no deprecation period required.

PiperOrigin-RevId: 681166972
2024-10-01 14:32:22 -07:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
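
A minimal sketch of a test-local warning suppression of the kind described above, using only the standard library. `legacy_api` and its warning message are hypothetical; the real tests suppress whatever warnings they actually trigger.

```python
import unittest
import warnings


def legacy_api():
    # Hypothetical function that emits a deprecation warning.
    warnings.warn("legacy_api is deprecated", DeprecationWarning, stacklevel=2)
    return 42


class LegacyApiTest(unittest.TestCase):
    def test_legacy_api(self):
        # Suppress only the warning this test is known to trigger, so the
        # rest of the suite can run with warnings promoted to errors.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="legacy_api is deprecated",
                category=DeprecationWarning,
            )
            self.assertEqual(legacy_api(), 42)
```

Because `catch_warnings` restores the previous filters on exit, the suppression does not leak into other tests.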
2024-09-24 01:38:24 +00:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
jax authors
44241eeab1 Add a beacon metric to report tasks that have disabled the JAX compilation cache.
- Create metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks which have disabled compilation cache.
- Modify the existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric and make it easier to find the two adoption-related metrics in the code.

PiperOrigin-RevId: 654970654
2024-07-22 18:44:18 -07:00
rdyro
c6d6207170 Unifying persistent cache messages
and moving them to WARNING logging when explain_cache_misses is true.
2024-07-16 00:47:53 +00:00
Ayaka
bc7addf938 Improve compilation cache tests
2024-06-27 19:51:19 +04:00
Parker Schuh
eba0564b70 Allow disabling compilation cache for particular runtime_types.
PiperOrigin-RevId: 640264856
2024-06-04 13:30:42 -07:00
Jake VanderPlas
57f70e2a60 Avoid global state in compilation_cache_test
2024-05-29 12:54:06 -07:00
jax authors
cc3a380f76 Add unit test to check that the backend serialization/deserialization result is equal to the original executable.
PiperOrigin-RevId: 635485374
2024-05-20 09:52:38 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Peter Hawkins
7018e0b085 Fix warnings in CI from compilation_cache_test.
Whether the jitted function __eq__ is cached changes the number of warnings we expect.
2024-04-30 13:40:35 +00:00
Yunlong Liu
2df6b35dce Adds meaningful function names for better debugging.
The default `fn.__name__` was added in `_one_to_one_unop` but not in other functions, so many downstream function wrappers show uninformative names while debugging. For instance,

When a JAX numpy primitive `lax.add` is wrapped by `lu.WrappedFun`, `print(wrapped)` will give,

```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: fn
```
instead of
```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: add
```
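
A minimal sketch of the naming issue, independent of JAX. The factory names `make_unop_unnamed` and `make_unop_named` are hypothetical; the point is only that an inner function keeps the name `fn` unless `__name__` is set explicitly.

```python
def make_unop_unnamed(op):
    def fn(x):
        return op(x)
    return fn  # every wrapper built this way is named "fn"


def make_unop_named(op, name):
    def fn(x):
        return op(x)
    # Give the wrapper a meaningful name for printing and debugging.
    fn.__name__ = name
    fn.__qualname__ = name
    return fn


negate_a = make_unop_unnamed(lambda x: -x)
negate_b = make_unop_named(lambda x: -x, "negate")
print(negate_a.__name__)  # fn
print(negate_b.__name__)  # negate
```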
PiperOrigin-RevId: 627417452
2024-04-23 09:45:57 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps
2024-03-19 09:24:52 -07:00
jax authors
3708336f8f To avoid the inconsistency between process_index and process_id, replace backend.process_index with distributed.global_state.process_id in the JAX compilation cache _cache_write function.
Testing: new unit test.
PiperOrigin-RevId: 607385112
2024-02-15 10:48:05 -08:00
Eugene Zhulenev
28ef77dfb0 Disable JAX compilation cache for XLA:CPU
PiperOrigin-RevId: 603551262
2024-02-01 19:20:39 -08:00
Peter Hawkins
a7023b18d5 [JAX] Disable a compilation cache test that fails on Windows in CI.
PiperOrigin-RevId: 599901235
2024-01-19 12:07:25 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
Eugene Zhulenev
ba4c2b1c7d [pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since the introduction of the min cache entry size check for the compilation cache, the cache miss metric overcounts: it includes skipped entries whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric is incremented only if both the compile time and the cache entry size are greater than their minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
jax authors
ea66029731 Introduce min entry size check for compilation cache.
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is at least the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.

Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.
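
The two checks can be sketched as follows. This is an illustration under stated assumptions, not the cache's implementation: `should_cache`, its parameter names, the default thresholds, and the use of `zlib` for compression are all hypothetical.

```python
import zlib


def should_cache(compile_time_secs, serialized_executable,
                 min_compile_time_secs=1.0, min_entry_size_bytes=64):
    """Return True if an entry passes both the time and the size threshold."""
    # Time check: skip executables that were cheap to compile.
    if compile_time_secs < min_compile_time_secs:
        return False
    # Size check: measured after serialization and compression.
    entry = zlib.compress(serialized_executable)
    return len(entry) >= min_entry_size_bytes
```

With these hypothetical defaults, a slow compilation that yields a tiny executable is skipped, which avoids filling the cache directory with many small files.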

Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
2024-01-04 15:17:05 -08:00
jax authors
c28aa2cecc Remove CPU support from compilation cache.
CPU support was originally added to the compilation cache
in anticipation of the availability of CPU acceleration
compilation. Since this is not available and the
--xla_cpu_use_xla_runtime flag has been deprecated,
cleanup the code and test.

Testing: test workload, revised unit test.
PiperOrigin-RevId: 592962316
2023-12-21 15:25:56 -08:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform if the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
jax authors
d3f4bbfdd0 Fix cache_used metric implementation and test.
The cache_used metric is incremented once per task and is
used to determine how many tasks are using the Jax
compilation cache. The current implementation and unit
test are not thread safe. This results in the test
failing when unit tests are executed concurrently.

The fix is to make the implementation thread safe and
to update the test to examine the delta in the metric.
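
A minimal sketch of a once-per-task counter guarded by a lock, with a check on the delta rather than the absolute value. The class `TaskUsingCacheMetric` and its method are hypothetical and stand in for the real metric plumbing.

```python
import threading


class TaskUsingCacheMetric:
    """Counts at most one event per task, safely across threads."""

    def __init__(self):
        self._lock = threading.Lock()
        self._recorded = False
        self.count = 0

    def record_once(self):
        # The lock makes the check-and-set atomic, so concurrent callers
        # cannot both observe `_recorded == False`.
        with self._lock:
            if self._recorded:
                return
            self._recorded = True
            self.count += 1


metric = TaskUsingCacheMetric()
before = metric.count
threads = [threading.Thread(target=metric.record_once) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
delta = metric.count - before  # compare the change, not the absolute value
```

Asserting on `delta` keeps the test valid even if other tests in the same process have already touched the counter.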

Testing: Cloud TPU VM testing; test workload.
PiperOrigin-RevId: 589174850
2023-12-08 10:30:22 -08:00
jax authors
896d4cfbaf Disable task_using_cache_metric unit test while debugging.
This test is failing in the OSS environment. Temporarily
disabling the test while debugging.

PiperOrigin-RevId: 586144501
2023-11-28 17:04:23 -08:00
jax authors
b9b5410ddd Default-enable the Jax persistent compilation cache.
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.

Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.

Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
2023-11-27 14:53:20 -08:00
jax authors
45982c8439 Update test since C PjRt API topology description is available.
The compilation_cache_test had an exclusion since the C PjRt
topology description had not been implemented. Now that it is
available, remove the exclusion.

PiperOrigin-RevId: 581396824
2023-11-10 16:12:06 -08:00
jax authors
7e372944f9 Fix the missing cache_misses metric when min compile time is set to zero.
Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.

Testing: revised unit test.
PiperOrigin-RevId: 579951415
2023-11-06 14:04:35 -08:00
jax authors
9ba305cced Invalidate in-memory caches on XLA-AutoFDO profile version change.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.
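
A minimal sketch of invalidating an in-memory cache when a version value changes. `VersionedCache` and `get_or_compile` are hypothetical names; the real change concerns JAX's tracing and compilation caches, which are not modelled here.

```python
class VersionedCache:
    """In-memory cache that is cleared whenever the profile version changes."""

    def __init__(self):
        self._version = 0
        self._entries = {}
        self.compilations = 0

    def get_or_compile(self, key, profile_version, compile_fn):
        if profile_version != self._version:
            # Entries built with the old profile are stale: drop them all.
            self._entries.clear()
            self._version = profile_version
        if key not in self._entries:
            self._entries[key] = compile_fn()
            self.compilations += 1
        return self._entries[key]
```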

Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
  the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
2023-10-27 12:52:57 -07:00
jax authors
65cfe1a5a3 Instrument metrics for the new JAX compilation cache key generation algorithm.
Metrics:
1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation.
2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation.
PiperOrigin-RevId: 573019115
2023-10-12 14:56:02 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
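
A minimal sketch of the more flexible tag matching this design allows. The function `device_matches` and the alias table are hypothetical illustrations, not the implementation of `test_device_matches()`.

```python
# Hypothetical alias table: a generic tag expands to concrete device types.
_DEVICE_TAG_ALIASES = {"gpu": {"cuda", "rocm"}}


def device_matches(device_under_test, tags):
    """Return True if the device equals a tag or is covered by a tag's aliases."""
    for tag in tags:
        if tag == device_under_test:
            return True
        if device_under_test in _DEVICE_TAG_ALIASES.get(tag, ()):
            return True
    return False
```

A plain equality test on the device name cannot express "gpu matches both cuda and rocm", whereas a matching function can.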

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
jax authors
6869000636 Modify the name for JAX compilation cache adoption metric.
Rename the adoption metric from '/jax/compilation_cache/tasks_using_original_cache' to '/jax/compilation_cache/tasks_using_cache' to record the number of tasks using compilation cache for both original and new implementations.

We realized that one adoption metric is enough to monitor and compare the original and new cache adoption rates. To avoid confusion over the word 'original' in the metric name, we decided to rename the metric to better describe its purpose.

Testing: revised unit test.
PiperOrigin-RevId: 565197027
2023-09-13 16:44:10 -07:00
jax authors
80f6151110 Instrument metrics to track cache hit rate of original JAX compilation cache.
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.

Created a context manager to register/unregister event listeners.

PiperOrigin-RevId: 561771262
2023-08-31 15:05:23 -07:00
jax authors
fe107fd119 Create a context manager to register/unregister event duration listeners.
PiperOrigin-RevId: 559209418
2023-08-22 13:35:45 -07:00
jax authors
27d19ee233 Instrument a metric to measure the number of tasks using compilation cache in JAX -> PJRT.
Create the metric '/jax/compilation_cache/tasks_using_original_cache' to record the number of tasks using compilation cache.

PiperOrigin-RevId: 559159282
2023-08-22 10:48:58 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
jax authors
35d33f620c Instrument metrics to measure compilation cache savings in JAX -> PJRT.
Create metrics:
1) '/jax/compilation_cache/cache_retrieval_time_sec' to record the time duration for getting cache entries.
2) '/jax/compilation_cache/original_compile_time_saved_sec' to record the time saved on cache hits.

PiperOrigin-RevId: 556243588
2023-08-12 00:20:42 -07:00
jax authors
3b28d4e180 Refactor Jax compilation cache key generation.
This is in preparation for introducing a more robust key-generation
algorithm.

This refactoring does not introduce any change in behavior.

Testing: refactored unit tests and test workload.
PiperOrigin-RevId: 551744892
2023-07-27 23:01:00 -07:00
jax authors
0c4c020716 Include compile time along with executable in cache entry.
In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit.
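
A minimal sketch of storing a compile time next to a serialized executable. The byte layout (a 4-byte little-endian integer header) and the names `pack_entry` and `unpack_entry` are assumptions made for illustration; the actual entry format is not described in this log.

```python
import struct

# Hypothetical layout: 4-byte little-endian compile time in whole seconds,
# followed by the serialized executable bytes.
_HEADER = struct.Struct("<I")


def pack_entry(serialized_executable, compile_time_secs):
    return _HEADER.pack(int(compile_time_secs)) + serialized_executable


def unpack_entry(entry):
    (compile_time_secs,) = _HEADER.unpack_from(entry)
    return entry[_HEADER.size:], compile_time_secs
```

On a cache hit, the stored compile time can be compared with the retrieval time to estimate the time saved.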

Testing: updated tests.
PiperOrigin-RevId: 549439628
2023-07-19 15:17:45 -07:00
Jake VanderPlas
b9c7b9bb4f Remove obsolete jaxlib version checks
2023-07-12 11:53:55 -07:00
Tao Wang
6eb3096461 Enable setting fdo_profile through the XLA Python client.
PiperOrigin-RevId: 547303330
2023-07-11 14:47:39 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Peter Hawkins
1d63d9b833 Include the device_kind in the compilation cache key.
PiperOrigin-RevId: 525726898
2023-04-20 06:16:45 -07:00
Peter Hawkins
34fd4a1562 Add version guard to compilation cache test.
PiperOrigin-RevId: 525572568
2023-04-19 15:50:33 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information, is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which raises the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
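
A minimal sketch of why stripping locations yields more cache hits. It works on module text with a regular expression, which is a simplified stand-in and not the MLIR pass itself; `cache_key` and the sample strings are hypothetical.

```python
import hashlib
import re

# Simplified stand-in for location stripping: removes trailing `loc(...)`
# annotations that contain no nested parentheses.
_LOC_SUFFIX = re.compile(r"\s*loc\([^()]*\)")


def cache_key(module_text, strip_metadata=True):
    if strip_metadata:
        module_text = _LOC_SUFFIX.sub("", module_text)
    return hashlib.sha256(module_text.encode("utf-8")).hexdigest()


# The same operation, emitted from two different source lines.
a = 'stablehlo.add %0, %1 : tensor<f32> loc("model.py":10:4)'
b = 'stablehlo.add %0, %1 : tensor<f32> loc("model.py":42:4)'
```

With stripping enabled both strings hash to the same key; with it disabled, moving the function to another line changes the key and causes a miss.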

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
017548c40b Move implementation of compilation cache out of jax/experimental and into jax/_src.
Use a Protocol instead of an abstract base class for the CacheInterface since it allows us to use one fewer file.
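
A minimal sketch of a structural interface defined with `typing.Protocol`. The method names `get` and `put` and their signatures are assumptions for illustration, and `InMemoryCache` and `store` are hypothetical.

```python
from typing import Optional, Protocol


class CacheInterface(Protocol):
    # Assumed method set; any class with matching methods satisfies it.
    def get(self, key: str) -> Optional[bytes]: ...

    def put(self, key: str, value: bytes) -> None: ...


class InMemoryCache:
    """Satisfies CacheInterface structurally, without inheriting from it."""

    def __init__(self):
        self._data = {}

    def get(self, key):
        return self._data.get(key)

    def put(self, key, value):
        self._data[key] = value


def store(cache: CacheInterface, key: str, value: bytes) -> Optional[bytes]:
    cache.put(key, value)
    return cache.get(key)
```

With a Protocol, implementations do not need to import or subclass a base class, so the interface can live in the same file as its users.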

No functional change intended.

PiperOrigin-RevId: 524855263
2023-04-17 08:35:53 -07:00