98 Commits

Author SHA1 Message Date
Bart Chrzaszcz
db8c8fc37c #sdy unskip JAX Shardy tests that are already passing
PiperOrigin-RevId: 718898708
2025-01-23 09:26:38 -08:00
jax authors
f6243ff8e1 Merge pull request #25889 from Stella-S-Yan:cache_reset
PiperOrigin-RevId: 718537398
2025-01-22 14:52:05 -08:00
Stella S Yan
f87c94db75 Fix cache init when JAX Array is created early (#25768) 2025-01-22 03:44:29 +00:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
Peter Hawkins
5a250097e4 Fix Windows portability problem in compilation cache test. 2024-12-04 10:00:28 -05:00
Dan Foreman-Mackey
a59bbb7cd7 Add test utility for accessing jaxlib version tuple.
We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests.

PiperOrigin-RevId: 698097487
2024-11-19 12:00:32 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features as well
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.

Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Peter Hawkins
fc4f554e09 Delete jax.lib.xla_client.execute_with_python_values.
Nothing under jax.lib.xla_client is public, so there's no deprecation period required.

PiperOrigin-RevId: 681166972
2024-10-01 14:32:22 -07:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
2024-09-24 01:38:24 +00:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
jax authors
44241eeab1 Add a beacon metric to report tasks disabled Jax compilation cache.
- Create metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks which have disabled compilation cache.
- Modified existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric and make it easier to find the two adoption related metrics in the code.

PiperOrigin-RevId: 654970654
2024-07-22 18:44:18 -07:00
rdyro
c6d6207170 Unifying persistent cache messages
and moving them to WARNING logging when explain_cache_misses is true.
2024-07-16 00:47:53 +00:00
Ayaka
bc7addf938 Improve compilation cache tests 2024-06-27 19:51:19 +04:00
Parker Schuh
eba0564b70 Allow disabling compilation cache for particular runtime_types.
PiperOrigin-RevId: 640264856
2024-06-04 13:30:42 -07:00
Jake VanderPlas
57f70e2a60 Avoid global state in compilation_cache_test 2024-05-29 12:54:06 -07:00
jax authors
cc3a380f76 Add unit test to check if the backend serialization/deserialization result equal to the original executable.
PiperOrigin-RevId: 635485374
2024-05-20 09:52:38 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Peter Hawkins
7018e0b085 Fix warnings in CI from compilation_cache_test.
Whether the jitted function __eq__ is cached changes the number of warnings we expect.
2024-04-30 13:40:35 +00:00
Yunlong Liu
2df6b35dce Adds meaningful function names for better debugging.
The default `fn.__name__` was added in `_one_to_one_unop` but not other functions so that it leads to many downstream function wrappers giving unmeaningful names while debugging. For instance,

When a JAX numpy primitive `lax.add` is wrapped by `lu.WrappedFun`, `print(wrapped)` will give,

```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: fn
```
instead of
```
Wrapped function:
0   : _argnums_partial   ((0, 1), ())
1   : flatten_fun   (PyTreeDef(((*, *), {})),)
2   : result_paths   ()
Core: add
```
PiperOrigin-RevId: 627417452
2024-04-23 09:45:57 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps 2024-03-19 09:24:52 -07:00
jax authors
3708336f8f To avoid the inconsistency between process_index and process_id, replace backend.process_index with distributed.global_state.process_id in Jax compilation _cache_write function.
Testing: new unit test.
PiperOrigin-RevId: 607385112
2024-02-15 10:48:05 -08:00
Eugene Zhulenev
28ef77dfb0 Disable JAX compilation cache for XLA:CPU
PiperOrigin-RevId: 603551262
2024-02-01 19:20:39 -08:00
Peter Hawkins
a7023b18d5 [JAX] Disable a compilation cache test that fails on Windows in CI.
PiperOrigin-RevId: 599901235
2024-01-19 12:07:25 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
Eugene Zhulenev
ba4c2b1c7d [pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
jax authors
ea66029731 Introduce min entry size check for compilation cache.
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is less than the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.

Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.

Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
2024-01-04 15:17:05 -08:00
jax authors
c28aa2cecc Remove CPU support from compilation cache.
CPU support was originally added to the compilation cache
in anticipation of the availability of CPU acceleration
compilation. Since this is not available and the
--xla_cpu_use_xla_runtime flag has been deprecated,
cleanup the code and test.

Testing: test workload, revised unit test.
PiperOrigin-RevId: 592962316
2023-12-21 15:25:56 -08:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
jax authors
d3f4bbfdd0 Fix cache_used metric implementation and test.
The cache_used metric is incremented once per task and is
used to determine how many tasks are using the Jax
compilation cache. The current implementation and unit
test are not thread safe. This results in the test
failing when unit tests are executed concurrently.

The fix is to make the implementation thread safe and
to update the test to examine the delta in the metric.

Testing: Cloud TPU VM testing; test workload.
PiperOrigin-RevId: 589174850
2023-12-08 10:30:22 -08:00
jax authors
896d4cfbaf Disable task_using_cache_metric unit test while debugging.
This test is failing in the OSS environment. Temporarily
disabling the test while debugging.

PiperOrigin-RevId: 586144501
2023-11-28 17:04:23 -08:00
jax authors
b9b5410ddd Default-enable the Jax persistent compilation cache.
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.

Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.

Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
2023-11-27 14:53:20 -08:00
jax authors
45982c8439 Update test since C PjRt API topology description is available.
The compilation_cache_test had an exclusion since the C PjRt
topology description had not been implemented. Now that it is
available, remove the exclusion.

PiperOrigin-RevId: 581396824
2023-11-10 16:12:06 -08:00
jax authors
7e372944f9 Fix the missing cache_misses metric when min compile time is set to zero.
Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.

Testing: revised unit test.
PiperOrigin-RevId: 579951415
2023-11-06 14:04:35 -08:00
jax authors
9ba305cced Invalidate in-memory caches on XLA-AutoFDO profile version change.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.

Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
  the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
2023-10-27 12:52:57 -07:00
jax authors
65cfe1a5a3 Instrument metrics for the new JAX compilation cache key generation algorithm.
Metrics:
1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation.
2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation.
PiperOrigin-RevId: 573019115
2023-10-12 14:56:02 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
jax authors
6869000636 Modify the name for JAX compilation cache adoption metric.
Rename the adoption metric from '/jax/compilation_cache/tasks_using_original_cache' to '/jax/compilation_cache/tasks_using_cache' to record the number of tasks using compilation cache for both original and new implementations.

We realized that one adoption metric is enough to monitor and compare the original and new cache adoption rates. To avoid confusion of the word 'original' in the metric name, we decide to change the metric name to better describe the purpose of this adoption metric.

Testing: revised unit test.
PiperOrigin-RevId: 565197027
2023-09-13 16:44:10 -07:00
jax authors
80f6151110 Instrument metrics to track cache hit rate of original JAX compilation cache.
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of  the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.

Created a context manager to register/unregister event listeners.

PiperOrigin-RevId: 561771262
2023-08-31 15:05:23 -07:00
jax authors
fe107fd119 Create a context manager to register/unregister event duration listeners.
PiperOrigin-RevId: 559209418
2023-08-22 13:35:45 -07:00
jax authors
27d19ee233 Instrument a metric to measure the number of tasks using compilation cache in JAX -> PJRT.
Create the metric '/jax/compilation_cache/tasks_using_original_cache' to record the number of tasks using compilation cache.

PiperOrigin-RevId: 559159282
2023-08-22 10:48:58 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
jax authors
35d33f620c Instrument metrics to measure compilation cache savings in JAX -> PJRT.
Create metrics:
1) '/jax/compilation_cache/cache_retrieval_time_sec' to record the time duration for getting cache entries.
2) '/jax/compilation_cache/original_compile_time_saved_sec' to record the time saved on cache hits.

PiperOrigin-RevId: 556243588
2023-08-12 00:20:42 -07:00
jax authors
3b28d4e180 Refactor Jax compilation cache key generation.
This is in preparation for introducing a more robust key-generation
algorithm.

This refactoring does not introduce any change in behavior.

Testing: refactored unit tests and test workload.
PiperOrigin-RevId: 551744892
2023-07-27 23:01:00 -07:00