Merge pull request #25444 from justinjfu:gpu_docs_update
PiperOrigin-RevId: 705938411
This commit is contained in commit 99b390ce96.
@@ -44,11 +44,8 @@ example, we can add this to the top of a Python file:

```python
import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)
```
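One caveat worth noting: XLA parses `XLA_FLAGS` when its backend is initialized, so these assignments should run before `jax` is imported. As a minimal sketch (an illustrative pattern, not part of the original docs), flags can also be appended to whatever the launching environment already set, rather than overwriting it:

```python
import os

# Append to any XLA_FLAGS already set by the launching environment instead of
# clobbering it. This append pattern is an illustrative assumption, not part
# of the original documentation.
existing = os.environ.get('XLA_FLAGS', '')
os.environ['XLA_FLAGS'] = (
    existing + ' --xla_gpu_enable_latency_hiding_scheduler=true'
).strip()

import jax  # import JAX only after XLA_FLAGS is finalized
```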
@@ -58,9 +55,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
### Code generation flags
* **--xla_gpu_enable_triton_softmax_fusion** Enables automatic softmax
  fusion, based on pattern matching backed by Triton code generation. The
  default value is False.
* **--xla_gpu_triton_gemm_any** Uses the Triton-based GEMM (matmul) emitter
  for any GEMM that it supports. The default value is False.
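As a minimal sketch of opting in to both code-generation flags (assuming a jaxlib/XLA build that still recognizes them, and with an illustrative workload that is not from the original docs), they can be set through `XLA_FLAGS` in the same way as above:

```python
import os

# Opt in to the Triton-backed softmax fusion and the Triton GEMM emitter.
# Both flags default to False; setting them here affects only this process.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
)

import jax
import jax.numpy as jnp

@jax.jit
def attention_scores(q, k):
    # A matmul followed by a softmax: with the flags above, XLA may lower the
    # matmul through the Triton GEMM emitter and fuse the softmax via Triton.
    return jax.nn.softmax(q @ k.T, axis=-1)

q = jnp.ones((128, 64), dtype=jnp.float32)
k = jnp.ones((128, 64), dtype=jnp.float32)
print(attention_scores(q, k).shape)  # (128, 128)
```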