Move PGLE documentation to JAX docs.
PiperOrigin-RevId: 739865595

# GPU performance tips

<!--* freshness: { reviewed: '2025-03-20' } *-->
This document focuses on performance tips for neural network workloads

* **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for
  any GEMM that it supports. The default value is False.

## Communication tips

### Auto and manual PGLE

The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time
of compute and collectives; the profile information is then fed back into the XLA compiler
for better scheduling decisions.

The Profile Guided Latency Estimator can be used manually or automatically. In auto mode,
JAX collects profile information and recompiles a module within a single run, while in
manual mode you need to run the task twice: the first time to collect and save profiles,
and the second to compile and run with the provided data.

### Auto PGLE

Auto PGLE can be turned on by setting the following environment variables:

Mandatory:
```bash
export JAX_ENABLE_PGLE=true

# For JAX version <= 0.5.0 make sure to include:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
```

Optional:
```bash
export JAX_PGLE_PROFILING_RUNS=3
export JAX_PGLE_AGGREGATION_PERCENTILE=85

# Right now, auto PGLE profile collection doesn't work with command buffers.
# If command buffers are enabled, auto PGLE will disable them during profile
# collection and re-enable them after recompilation. If you need consistent
# command buffer behavior with and without a PGLE profile, you can disable
# command buffers manually:
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"
```

Or this can be set in JAX as follows:

```python
import jax
from jax._src import config

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  # Run with the profiler collecting performance information.
  train_step()
  # Automatically re-compile with PGLE profile results.
  train_step()
  ...
```

You can control the number of reruns used to collect profile data by changing `JAX_PGLE_PROFILING_RUNS`.
Increasing this parameter leads to better profile information, but it also increases the
number of non-optimized training steps.

Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter might help when performance
between steps is too noisy to filter out irrelevant measurements.
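
These knobs can also be set programmatically. A minimal sketch, assuming the internal
`jax._src.config` module exposes a `pgle_aggregation_percentile` context manager analogous
to the `pgle_profiling_runs` one used above (the values here are illustrative, not
recommendations):

```python
import jax
from jax._src import config  # internal API, as used elsewhere in this document

# More profiling runs yield better latency data at the cost of extra
# non-optimized steps; a lower percentile damps step-to-step noise.
with config.enable_pgle(True), \
     config.pgle_profiling_runs(5), \
     config.pgle_aggregation_percentile(50):
  for _ in range(10):
    train_step()  # assumes a train_step like the one in the examples above
```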

**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX needs to recompile the module during execution, auto PGLE will work neither for ahead-of-time (AoT) compilation nor for the following case:

```python
import jax
from jax._src import config

# Compile the module ahead of time.
train_step_compiled = jax.jit(train_step).lower().compile()

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step_compiled()
  # No effect since the module was pre-compiled.
  train_step_compiled()
```

### Manual PGLE

If you still want to use the manual Profile Guided Latency Estimator, the workflow in XLA/GPU is:

1. Run your workload once, with async collectives and latency hiding scheduler enabled.

   You could do so by setting:

   ```bash
   export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
   ```

2. Collect and post-process a profile by using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.

   ```python
   import logging
   import os

   from etils import epath
   import jax
   from jax.experimental import profiler as exp_profiler

   # Define your profile directory.
   profile_dir = 'gs://my_bucket/profile'
   jax.profiler.start_trace(profile_dir)

   # Run your workflow, e.g.:
   # for i in range(10):
   #   train_step()

   # Stop the trace.
   jax.profiler.stop_trace()

   # Locate the profiler run directory under the trace output.
   profile_dir = epath.Path(profile_dir)
   directories = profile_dir.glob('plugins/profile/*/')
   directories = [d for d in directories if d.is_dir()]
   rundir = directories[-1]
   logging.info('rundir: %s', rundir)

   # Post-process the profile into an instruction-latency proto.
   fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

   # Save the profile proto to a file.
   dump_dir = rundir / 'profile.pb'
   dump_dir.parent.mkdir(parents=True, exist_ok=True)
   dump_dir.write_bytes(fdo_profile)
   ```

   After this step, you will get a `profile.pb` file under the `rundir` printed in the code.
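
   As a quick sanity check (a sketch reusing `dump_dir` from the block above), you can confirm the dumped profile is non-empty before recompiling with it:

   ```python
   # An empty profile.pb would mean no instruction latencies were extracted
   # from the trace, so feeding it back to the compiler would gain nothing.
   assert dump_dir.exists() and len(dump_dir.read_bytes()) > 0
   ```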

3. Run the workload again, feeding that file into the compilation.

   You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag.

   ```bash
   export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
   ```

To enable logging in XLA and check whether the profile is good, set the logging level to include `INFO`:

```bash
export TF_CPP_MIN_LOG_LEVEL=0
```

Run the real workflow; if you find these lines in the running log, it means the profile is being used by the latency hiding scheduler:

```
2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
```

#### Flags

* **--xla_gpu_enable_latency_hiding_scheduler** This flag enables the latency hiding
  scheduler to overlap asynchronous communication with computation efficiently.