The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.
This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.
If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.
PiperOrigin-RevId: 713296722
We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests.
PiperOrigin-RevId: 698097487
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike
Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.
Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.
PiperOrigin-RevId: 691496997
suppressions.
We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
- Create metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks which have disabled compilation cache.
- Modified existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric and make it easier to find the two adoption related metrics in the code.
PiperOrigin-RevId: 654970654
The default `fn.__name__` was added in `_one_to_one_unop` but not other functions so that it leads to many downstream function wrappers giving unmeaningful names while debugging. For instance,
When a JAX numpy primitive `lax.add` is wrapped by `lu.WrappedFun`, `print(wrapped)` will give,
```
Wrapped function:
0 : _argnums_partial ((0, 1), ())
1 : flatten_fun (PyTreeDef(((*, *), {})),)
2 : result_paths ()
Core: fn
```
instead of
```
Wrapped function:
0 : _argnums_partial ((0, 1), ())
1 : flatten_fun (PyTreeDef(((*, *), {})),)
2 : result_paths ()
Core: add
```
PiperOrigin-RevId: 627417452
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.
PiperOrigin-RevId: 621341638
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.
Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.
Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.
PiperOrigin-RevId: 596696013
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is less than the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.
Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.
Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
CPU support was originally added to the compilation cache
in anticipation of the availability of CPU acceleration
compilation. Since this is not available and the
--xla_cpu_use_xla_runtime flag has been deprecated,
cleanup the code and test.
Testing: test workload, revised unit test.
PiperOrigin-RevId: 592962316
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.
Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.
Testing: test workload.
PiperOrigin-RevId: 590378036
The cache_used metric is incremented once per task and is
used to determine how many tasks are using the Jax
compilation cache. The current implementation and unit
test are not thread safe. This results in the test
failing when unit tests are executed concurrently.
The fix is to make the implementation thread safe and
to update the test to examine the delta in the metric.
Testing: Cloud TPU VM testing; test workload.
PiperOrigin-RevId: 589174850
To increase the adoption of the compilation cache, we should
enable it by default. A prerequisite is to configure a default
cache directory.
Switch spherical_cnn molecules training and universal_diffusion
model wrapper to use the default cache.
Testing: manual testing with test workloads.
PiperOrigin-RevId: 585767363
The compilation_cache_test had an exclusion since the C PjRt
topology description had not been implemented. Now that it is
available, remove the exclusion.
PiperOrigin-RevId: 581396824
Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.
Testing: revised unit test.
PiperOrigin-RevId: 579951415
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.
Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
Metrics:
1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation.
2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation.
PiperOrigin-RevId: 573019115
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
Rename the adoption metric from '/jax/compilation_cache/tasks_using_original_cache' to '/jax/compilation_cache/tasks_using_cache' to record the number of tasks using compilation cache for both original and new implementations.
We realized that one adoption metric is enough to monitor and compare the original and new cache adoption rates. To avoid confusion of the word 'original' in the metric name, we decide to change the metric name to better describe the purpose of this adoption metric.
Testing: revised unit test.
PiperOrigin-RevId: 565197027
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.
Created a context manager to register/unregister event listeners.
PiperOrigin-RevId: 561771262
Create the metric '/jax/compilation_cache/tasks_using_original_cache' to record the number of tasks using compilation cache.
PiperOrigin-RevId: 559159282
Create metrics:
1) '/jax/compilation_cache/cache_retrieval_time_sec' to record the time duration for getting cache entries.
2) '/jax/compilation_cache/original_compile_time_saved_sec' to record the time saved on cache hits.
PiperOrigin-RevId: 556243588
This is in preparation for introducing a more robust key-generation
algorithm.
This refactoring does not introduce any change in behavior.
Testing: refactored unit tests and test workload.
PiperOrigin-RevId: 551744892