Overall, the idea is to collect profile data for each module a given number of times (which can be configured) and then recompile the module with the aggregated profile data.
1. We need to track how many times each module was profiled and collect the profiling results. For this I added a ProfileSessionRunner class in profile.py. The class can track how many times an instance of it was called to profile a session and can also aggregate the profile results (see the sketch after this list).
2. We need to associate a profiling session with its module at the interpreter level. To do this I added a dictionary in pjit.py which associates a Jaxpr with its profile session runner.
3. The profile session runner should be passed to pxla.py and then called.
4. We need to handle the fast path at the interpreter level correctly, so that JAX will not use the HLO directly while PGLE still needs to be collected, but also will not recompile the module solely for PGLE. See changes in pjit.py and in lru_cache.h.
5. Once the FDO profile is collected we need to share it between hosts to keep compilation deterministic.
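Below is a minimal sketch of the ProfileSessionRunner idea from point 1; the method names and the byte-concatenation "aggregation" are illustrative stand-ins, not the actual profile.py implementation.

```python
class ProfileSessionRunner:
    """Hedged sketch: count profiling runs for a module and collect results."""

    def __init__(self, profiles_required: int):
        self.profiles_required = profiles_required  # configurable number of runs
        self.runs = 0
        self.collected_profiles: list[bytes] = []

    def record(self, fdo_profile: bytes) -> None:
        # Called once per profiled execution of the associated module.
        self.runs += 1
        self.collected_profiles.append(fdo_profile)

    def is_ready_to_recompile(self) -> bool:
        return self.runs >= self.profiles_required

    def aggregate(self) -> bytes:
        # Placeholder aggregation; the real change merges the per-run FDO
        # profiles into one profile used to recompile the module.
        return b"".join(self.collected_profiles)
```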
PiperOrigin-RevId: 638197166
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.
PiperOrigin-RevId: 621341638
This CL adds `supports_executable_serialization` as an IFRT client attribute indicating whether the executables produced by a given IFRT implementation are serializable. This follows the principle that `xla::ifrt::Client::attributes()` returns a set of attributes representing the "capabilities" of an IFRT implementation, so that users of IFRT can act on those capabilities without having to know the exact backend they are using.
This change is backward compatible as IFRT backends that do not implement `supports_executable_serialization` are assumed to implement executable serialization.
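As a rough illustration of how a consumer might act on this capability (only the attribute name comes from this change; the dict-style access and the executable.serialize() call are assumptions):

```python
def serialize_if_supported(client_attributes: dict, executable):
    # Backends that do not report the attribute are assumed to support
    # executable serialization (the backward-compatible default).
    if client_attributes.get("supports_executable_serialization", True):
        return executable.serialize()
    return None
```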
PiperOrigin-RevId: 606799188
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.
Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.
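A usage sketch for the new API (the import path is an assumption, mirroring the one used by the deprecated helpers):

```python
from jax.experimental.compilation_cache import compilation_cache as cc

cc.set_cache_dir("/tmp/jax_cache")  # the cache itself is still initialized lazily
```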
Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
Since the introduction of the minimum cache entry size check for the compilation cache, the cache miss metric has overcounted: it also counts skipped entries whose sizes are smaller than the minimum cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric is incremented only if both the compile time and the cache entry size are greater than their minimum thresholds.
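A self-contained sketch of the accounting described above; the constants and the backing store are illustrative, not the actual compilation_cache internals:

```python
MIN_COMPILE_TIME_SECS = 1.0
MIN_ENTRY_SIZE_BYTES = 1024

cache_miss_count = 0
_cache_store: dict[str, bytes] = {}

def put_executable_and_time(cache_key: str, serialized: bytes,
                            compile_time_secs: float) -> None:
    """Count a miss only when the entry is actually written to the cache."""
    global cache_miss_count
    if (compile_time_secs > MIN_COMPILE_TIME_SECS
            and len(serialized) > MIN_ENTRY_SIZE_BYTES):
        cache_miss_count += 1  # the '/jax/compilation_cache/cache_misses' metric
        _cache_store[cache_key] = serialized
```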
PiperOrigin-RevId: 596696013
CPU support was originally added to the compilation cache
in anticipation of the availability of CPU acceleration
compilation. Since this is not available and the
--xla_cpu_use_xla_runtime flag has been deprecated,
clean up the code and test.
Testing: test workload, revised unit test.
PiperOrigin-RevId: 592962316
Previously the environment variable JAX_DUMP_IR_TO controlled
whether and where to dump the MLIR module prior to compilation. Now we move the code for that support from
compiler.py to mlir.py, so that it can be used in other
parts of the code. We also add support for logging to Sponge.
Using this support we now log the module on errors from
refine_polymorphic_shapes.
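A usage sketch for the environment variable (the dump path is arbitrary; it must be set before lowering happens):

```python
import os
os.environ["JAX_DUMP_IR_TO"] = "/tmp/jax_ir_dumps"

import jax
import jax.numpy as jnp

jax.jit(lambda x: x + 1)(jnp.arange(4))  # the MLIR module is dumped prior to compilation
```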
PiperOrigin-RevId: 592099633
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.
Note: since we still fall back to hashing devices +
platform when the PjRtTopologyDescription serialization has
not been implemented by a backend, we retain those for now.
Testing: test workload.
PiperOrigin-RevId: 590378036
The cache_used metric is incremented once per task and is
used to determine how many tasks are using the JAX
compilation cache. The current implementation and unit
test are not thread safe, which causes the test to fail
when unit tests are executed concurrently.
The fix is to make the implementation thread safe and
to update the test to examine the delta in the metric.
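A hedged sketch of the thread-safe "increment at most once per task" pattern described above; the names and the metric path are illustrative, not the actual JAX internals:

```python
import threading

_metric_lock = threading.Lock()
_cache_used_reported = False

def record_cache_used(increment_metric) -> None:
    """Report the cache_used metric at most once per task, safely across threads."""
    global _cache_used_reported
    with _metric_lock:
        if not _cache_used_reported:
            _cache_used_reported = True
            increment_metric('/jax/compilation_cache/cache_used')  # path illustrative
```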
Testing: Cloud TPU VM testing; test workload.
PiperOrigin-RevId: 589174850
The logs related to the compilation cache ended up being quite chatty,
which is unlike the other logs in JAX. This downgrades a bunch
of them to debug, as they can always be enabled independently
using JAX config. This should also fix the recent failures in
logging_test.py.
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.
Testing: test workload.
PiperOrigin-RevId: 584148686
Remove the code which checks whether the min compile time is greater than zero. After this change, we can count cache misses even when the min compile time is zero.
Testing: revised unit test.
PiperOrigin-RevId: 579951415
Metrics:
1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation.
2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation.
PiperOrigin-RevId: 573019115
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
PiperOrigin-RevId: 571932143
We want to suppress detailed logging (notably on TPU, which has pretty verbose detailed logging) separately from disabling HLO dumps. Even if we don't print detailed log information, it's quite surprising if an HLO module doesn't show up in the set of modules dumped by XLA.
PiperOrigin-RevId: 570374492
Rename the adoption metric from '/jax/compilation_cache/tasks_using_original_cache' to '/jax/compilation_cache/tasks_using_cache' to record the number of tasks using the compilation cache for both the original and new implementations.
We realized that one adoption metric is enough to monitor and compare the original and new cache adoption rates. To avoid confusion over the word 'original' in the metric name, we decided to rename the metric to better describe its purpose.
Testing: revised unit test.
PiperOrigin-RevId: 565197027
The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized form of
PjRtTopologyDescription, which is supported for all accelerators
(a sketch follows the notes below).
Note:
. CPU and PjRt C API not supported as yet.
. Stream Executor will not be supported.
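A hedged sketch of the idea: hash the serialized topology rather than device/backend identity, with a fallback for backends that have not implemented serialization (the serialize() accessor and the fallback representation are assumptions, not the actual cache-key internals):

```python
import hashlib

def topology_key_component(topology) -> str:
    # Prefer the serialized PjRtTopologyDescription, which is accelerator-
    # agnostic; fall back to hashing a device/platform description when a
    # backend has not implemented serialization.
    try:
        serialized = topology.serialize()
    except NotImplementedError:
        serialized = repr(topology).encode()
    return hashlib.sha256(serialized).hexdigest()
```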
Testing: revised unit test.
PiperOrigin-RevId: 564461564
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times the cache is missed and the compiled executable is written to the cache repository.
Created a context manager to register/unregister event listeners.
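A hypothetical sketch of such a register/unregister context manager; the hook functions are passed in here because the actual monitoring API is not shown in this change:

```python
import contextlib

@contextlib.contextmanager
def event_listener(register_fn, unregister_fn, listener):
    # Register the listener on entry and always unregister it on exit,
    # even if the wrapped block raises.
    register_fn(listener)
    try:
        yield listener
    finally:
        unregister_fn(listener)
```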
PiperOrigin-RevId: 561771262
Create the metric '/jax/compilation_cache/tasks_using_original_cache' to record the number of tasks using compilation cache.
PiperOrigin-RevId: 559159282