Move PGLE documentation to JAX docs.

PiperOrigin-RevId: 739865595
jax authors 2025-03-24 02:49:43 -07:00
parent a2475a66c5
commit 4da1faf5b6


# GPU performance tips
<!--* freshness: { reviewed: '2025-03-20' } *-->
This document focuses on performance tips for neural network workloads
* **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for
any GEMM that it supports. The default value is False.
## Communication tips
### Auto and manual PGLE
The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time
of compute and collectives, then feeds the profile information back into the XLA compiler
for better scheduling decisions.
The Profile Guided Latency Estimator can be used manually or automatically. In auto mode,
JAX collects profile information and recompiles a module within a single run. In manual
mode, you run the task twice: the first time to collect and save profiles,
and the second to compile and run with the provided data.
### Auto PGLE
Auto PGLE can be turned on by setting the following environment variables:
Mandatory:
```bash
export JAX_ENABLE_PGLE=true
# For JAX version <= 0.5.0 make sure to include:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
```
Optional:
```bash
export JAX_PGLE_PROFILING_RUNS=3
export JAX_PGLE_AGGREGATION_PERCENTILE=85
# Right now the auto PGLE profile collection doesn't work with command buffers.
# If command buffers are enabled, Auto PGLE will disable them during profile
# collection and re-enable them after the recompilation. If you need consistent
# command buffer behavior with and without PGLE profiling, you can disable
# them manually:
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"
```
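If you launch your job from a Python script, the same configuration can be applied by exporting these variables before JAX initializes its backend. The snippet below is a minimal sketch, assuming the variables only take effect when set before the first `import jax`:

```python
import os

# Set Auto PGLE knobs before JAX initializes the GPU backend.
os.environ["JAX_ENABLE_PGLE"] = "true"
os.environ["JAX_PGLE_PROFILING_RUNS"] = "3"
# For JAX <= 0.5.0, the latency hiding scheduler must be enabled explicitly.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # noqa: E402  (imported after the environment is configured)
```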
Or this can be set from within JAX as follows:
```python
import jax
from jax._src import config

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  # Run with the profiler collecting performance information.
  train_step()
  # Automatically re-compile with PGLE profile results.
  train_step()
  ...
```
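For a self-contained picture, here is a hedged end-to-end sketch. The `train_step` below is a hypothetical toy function standing in for a real training step; in practice PGLE pays off mainly for multi-device workloads with collectives:

```python
import jax
import jax.numpy as jnp
from jax._src import config

@jax.jit
def train_step(x):
  # Toy stand-in for a real training step.
  return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((1024, 1024))

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step(x).block_until_ready()  # profiled run
  train_step(x).block_until_ready()  # recompiled using the PGLE profile
```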
You can control the number of reruns used to collect profile data by changing `JAX_PGLE_PROFILING_RUNS`.
Increasing this parameter leads to better profile information, but it also increases the
number of non-optimized training steps.

Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter may help when performance between steps is too noisy to filter out non-relevant measures.
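If your JAX version also exposes these knobs as config options (an assumption; check `jax._src.config` for your release), the same tuning can be expressed programmatically:

```python
import jax
import jax.numpy as jnp
from jax._src import config

@jax.jit
def train_step(x):
  # Toy stand-in for a real training step.
  return jnp.mean(x * x)

x = jnp.ones((1024, 1024))

# Assumed API: `config.pgle_aggregation_percentile` mirrors
# JAX_PGLE_AGGREGATION_PERCENTILE; verify it exists in your JAX version.
with config.enable_pgle(True), \
     config.pgle_profiling_runs(5), \
     config.pgle_aggregation_percentile(85):
  for _ in range(10):
    train_step(x)  # the first 5 runs collect profiles, later runs use them
```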
**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX needs to recompile the module during execution, Auto PGLE works neither for ahead-of-time (AoT) compilation nor for the following case:
```python
import jax
from jax._src import config

train_step_compiled = train_step().lower().compile()

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step_compiled()
  # No effect since module was pre-compiled.
  train_step_compiled()
```
### Manual PGLE
If you instead want to use the manual Profile Guided Latency Estimator, the workflow in XLA/GPU is:
- 1. Run your workload once, with async collectives and the latency hiding scheduler enabled.
You could do so by setting:
```bash
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
```
- 2. Collect and post-process a profile by using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.
```python
import logging
import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler

# Define your profile directory
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)

# Run your workflow
# for i in range(10):
#   train_step()

# Stop trace
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
directories = [d for d in directories if d.is_dir()]
# Run directories are timestamped, so lexicographic order is chronological.
rundir = sorted(directories, key=os.fspath)[-1]
logging.info('rundir: %s', rundir)

# Post process the profile
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

# Save the profile proto to a file.
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)
```
After this step, you will get a `profile.pb` file under the `rundir` printed in the log.
- 3. Run the workload again feeding that file into the compilation.
You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag.
```bash
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
```
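Equivalently, from a Python launcher you can set the flag before importing JAX. This is a sketch assuming the `profile.pb` path produced in step 2:

```python
import os

# Hypothetical path to the profile dumped in step 2.
profile_path = "/path/to/profile/profile.pb"
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    f"--xla_gpu_pgle_profile_file_or_directory_path={profile_path}",
])

import jax  # noqa: E402  (imported after XLA_FLAGS is set)
```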
To enable logging in XLA and check whether the profile is being used, set the logging level to include `INFO`:
```bash
export TF_CPP_MIN_LOG_LEVEL=0
```
Run the real workflow; if you see these lines in the running log, it means the profile is being used by the latency hiding scheduler:
```
2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
```
### Flags
* **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
schedulers to overlap asynchronous communication with computation efficiently.