docs: compilation_cache_expect_pgle option

This commit is contained in:
Olli Lupton 2025-04-01 09:59:38 +00:00
parent 6242ffb1ca
commit 297a4f42de

View File

@ -71,6 +71,10 @@ JAX will collect profile information and recompile a module in a single run. Whi
in manual mode you need to run a task twice, the first time to collect and save profiles
and the second to compile and run with provided data.
**Important**: the JAX profiler, which is used by both of the PGLE workflows documented
below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be
avoided by using the JAX compilation cache, as described below.
### Auto PGLE
The auto PGLE can be turned on by setting the following environment variables:
@ -129,6 +133,28 @@ with config.enable_pgle(True), config.pgle_profiling_runs(1):
train_step_compiled()
```
#### Collecting NVIDIA Nsight Systems profiles when using AutoPGLE
[jax#24910](https://github.com/jax-ml/jax/pull/24910) (JAX v0.5.1 and newer) added a
new JAX configuration option, `JAX_COMPILATION_CACHE_EXPECT_PGLE`, which tells JAX to
attempt to load PGLE-optimized compiled functions from the persistent compilation
cache.
This allows a two-step process, where the first step writes a PGLE-optimized function
to the cache:
```bash
export JAX_ENABLE_COMPILATION_CACHE=yes # not strictly needed, on by default
export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
JAX_ENABLE_PGLE=yes python my-model.py
```
And the second step uses Nsight Systems and loads the PGLE-optimized function from the
cache:
```bash
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py
```
See also [this page](
https://docs.jax.dev/en/latest/persistent_compilation_cache.html#pitfalls) for more
information about the persistent compilation cache and possible pitfalls.
### Manual PGLE
If you still want to use a manual Profile Guided Latency Estimator the workflow in XLA/GPU is: