This allows external profiling tools, such as Nsight Systems, to be used
with JAX's automatic PGLE support via a simple two-step workflow:
```sh
export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
```
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line. Valid values are between -1.0 and 1.0; the default is 0.0.
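As a hedged illustration, the flags can also be set in code via `jax.config.update`. The flag names below are assumptions, since this excerpt does not name them:
```py
# Hedged sketch: the flag names here are assumptions, not given in this excerpt.
import jax

# Spend extra effort minimizing execution time, default effort on memory
# fitting. Valid range for both: -1.0 to 1.0.
jax.config.update('jax_exec_time_optimization_effort', 0.5)
jax.config.update('jax_memory_fitting_effort', 0.0)
```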
The OpenXLA project is working on an open-source, MLIR, named-axis-based propagation (and in the future SPMD partitioning) system that will be dialect agnostic (it will work for any dialect: MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. Eventually, though, we will make Shardy the first pass in the XLA pipeline, running while the program is still StableHLO. Partitioning (the system that adds collectives like all-gathers/all-reduces) will still be the GSPMD partitioner, but next year the Shardy partitioner will be developed, allowing propagation and partitioning to live entirely in MLIR as the first passes in the pipeline. The pipeline would then be:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
from functools import partial
import numpy as np
import jax
from jax._src import test_util as jtu
from jax.sharding import PartitionSpec as P

def test_sdy_lowering(self):  # method of a jtu.JaxTestCase subclass
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```
Shardy will initially be gated behind the `jax_use_shardy_partitioner` flag before being enabled by default in the future.
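For example, opting in explicitly (a minimal sketch using the flag named above):
```py
import jax

# Opt in to the experimental Shardy partitioner while it is off by default.
jax.config.update('jax_use_shardy_partitioner', True)
```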
PiperOrigin-RevId: 655127611
- Create the metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks that have disabled the compilation cache.
- Modify the existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric, making the two adoption-related metrics easier to find in the code.
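A minimal sketch of how such a beacon metric might be recorded, assuming JAX's internal `jax._src.monitoring` helper (the exact call site is not shown in this excerpt):
```py
# Hedged sketch; the real call site lives in the compilation cache code.
from jax._src import monitoring

def _note_cache_disabled() -> None:
  # A one-shot beacon event marking a task that runs with the cache disabled.
  monitoring.record_event('/jax/compilation_cache/task_disabled_cache')
```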
PiperOrigin-RevId: 654970654
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.
The renames are
* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
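A hedged sketch of the renamed helpers in use (internal `jax._src.config` API; the exact signatures are assumptions):
```py
from jax._src import config

# A flag: read at most once, e.g. from the command line or an env var.
EXAMPLE_FLAG = config.bool_flag(
    name='jax_example_flag', default=False, help='Hypothetical flag.')

# A state: can be overridden locally, per thread, via a context manager.
example_state = config.bool_state(
    name='jax_example_state', default=False, help='Hypothetical state.')

with example_state(True):
  pass  # code here observes the state enabled on this thread only
```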
Overall, the idea is to collect profile data for each module a given (configurable) number of times, then recompile the module with the aggregated profile data.
1. We need to track how many times each module has been profiled and collect the profiling results. For this I added a ProfileSessionRunner class in profile.py. The class tracks how many times an instance of it was called to profile a session and aggregates the profile results.
2. We need to associate a profiling session with the module at the interpreter level. To do this I added a dictionary to pjit.py that associates a Jaxpr with its profile session runner.
3. The profile session runner should be passed to pxla.py and then invoked there.
4. We need to handle the fast path at the interpreter level correctly, so that JAX does not use the HLO directly while PGLE data still needs to be collected, but also does not recompile the module solely for PGLE. See the changes in pjit.py and in lru_cache.h.
5. Once the FDO profile is collected, we need to share it between hosts to keep compilation deterministic.
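A minimal sketch of the bookkeeping in items 1 and 2, under invented names (the real ProfileSessionRunner lives in profile.py and the dictionary in pjit.py):
```py
# Hedged sketch; not the actual JAX implementation.
class ProfileSessionRunner:
  """Tracks how many times a module was profiled and aggregates the results."""

  def __init__(self, runs_needed: int):
    self.runs_needed = runs_needed  # configurable number of profiled runs
    self.fdo_profiles: list[bytes] = []

  def record(self, fdo_profile: bytes) -> None:
    self.fdo_profiles.append(fdo_profile)

  def ready_to_recompile(self) -> bool:
    return len(self.fdo_profiles) >= self.runs_needed

  def aggregate(self) -> bytes:
    # Stand-in for real FDO aggregation across the collected runs.
    return b''.join(self.fdo_profiles)

# Item 2: the interpreter keys runners by the traced jaxpr.
profile_sessions: dict = {}  # jaxpr -> ProfileSessionRunner
```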
PiperOrigin-RevId: 638197166
The CPU cache key now includes machine attributes, so incompatible CPUs
should no longer collide on the same cache entry.
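A hedged sketch of the idea, with invented helper names (the real key derivation lives in the XLA/JAX cache-key code):
```py
# Hedged sketch; not the actual cache-key implementation.
import hashlib
import platform

def cpu_cache_key(serialized_computation: bytes) -> str:
  h = hashlib.sha256(serialized_computation)
  # Folding machine attributes into the key keeps incompatible CPUs
  # from resolving to the same cache entry.
  h.update(platform.machine().encode())
  h.update(platform.processor().encode())
  return h.hexdigest()
```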
PiperOrigin-RevId: 621341638
This CL adds `supports_executable_serialization` as an IFRT client attribute indicating whether the executables produced by a given IFRT implementation are serializable. This is based on the principle that `xla::ifrt::Client::attributes()` returns a set of attributes representing the "capabilities" of an IFRT implementation, so that users of IFRT can act on such capabilities without having to know the exact backend they are using.
This change is backward compatible: IFRT backends that do not provide `supports_executable_serialization` are assumed to support executable serialization.
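A hedged Python sketch of the backward-compatible capability check (the real API is the C++ `xla::ifrt::Client::attributes()`; the names here are illustrative):
```py
# Hedged sketch of the capability-check pattern described above.
def executables_are_serializable(client_attributes: dict) -> bool:
  # Backends that predate the attribute are assumed to support serialization.
  return client_attributes.get('supports_executable_serialization', True)

print(executables_are_serializable({'supports_executable_serialization': False}))  # False
print(executables_are_serializable({}))  # True (backward-compatible default)
```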
PiperOrigin-RevId: 606799188
Since the compilation cache is now initialized lazily, the existing
APIs initialize_cache() and is_initialized() are confusing.
Deprecate these APIs.
Introduce a new API, set_cache_dir(), to explicitly set the
cache directory path in code.
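For example, matching JAX's documented usage of the compilation cache module:
```py
from jax.experimental.compilation_cache import compilation_cache as cc

# Explicitly set the persistent compilation cache directory in code.
cc.set_cache_dir('/tmp/jax_cache')
```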
Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
Since the introduction of the minimum cache entry size check for the compilation cache, the cache-miss metric has overcounted, including skipped entries whose sizes are smaller than the minimum. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache-miss metric is incremented only when both the compile time and the cache entry size exceed their minimum thresholds.
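A hedged sketch of that gating, with invented helper names and threshold values:
```py
# Hedged sketch; not the actual implementation of put_executable_and_time.
_cache: dict[str, bytes] = {}

def _record_cache_miss() -> None:
  pass  # stand-in for incrementing the cache-miss metric

def put_executable_and_time(key: str, executable: bytes, compile_time_secs: float,
                            min_compile_time_secs: float = 1.0,
                            min_entry_size_bytes: int = 64) -> None:
  # Count a miss (and store the entry) only when both thresholds are cleared.
  if (compile_time_secs > min_compile_time_secs
      and len(executable) > min_entry_size_bytes):
    _record_cache_miss()
    _cache[key] = executable
```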
PiperOrigin-RevId: 596696013
CPU support was originally added to the compilation cache
in anticipation of CPU acceleration compilation becoming
available. Since this is not available and the
--xla_cpu_use_xla_runtime flag has been deprecated,
clean up the code and tests.
Testing: test workload, revised unit test.
PiperOrigin-RevId: 592962316