Merge pull request #25444 from justinjfu:gpu_docs_update

PiperOrigin-RevId: 705938411
jax authors 2024-12-13 11:05:31 -08:00
commit 99b390ce96

@@ -44,11 +44,8 @@ example, we can add this to the top of a Python file:
 ```python
 import os
 os.environ['XLA_FLAGS'] = (
-    '--xla_gpu_enable_triton_softmax_fusion=true '
     '--xla_gpu_triton_gemm_any=True '
-    '--xla_gpu_enable_async_collectives=true '
     '--xla_gpu_enable_latency_hiding_scheduler=true '
-    '--xla_gpu_enable_highest_priority_async_stream=true '
 )
 ```
@@ -58,9 +55,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
 ### Code generation flags
-* **--xla_gpu_enable_triton_softmax_fusion** This flag enables an automatic
-  softmax fusion, based on pattern-matching backed by Triton code generation.
-  The default value is False.
 * **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for
   any GEMM that it supports. The default value is False.
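
For illustration (not part of the diff), a minimal sketch of how a single retained flag from the list above could be set; the flag name comes from the docs text, while the surrounding setup is assumed. `XLA_FLAGS` must be set before JAX initializes its backend.

```python
import os

# Sketch: opt in to the Triton-based GEMM emitter via XLA_FLAGS.
# Set this before the first `import jax` so XLA picks it up at backend init.
os.environ['XLA_FLAGS'] = '--xla_gpu_triton_gemm_any=True'

print(os.environ['XLA_FLAGS'])
```

Multiple flags are combined into the same space-separated string, as shown in the snippet earlier in this document.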