1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Remove deprecated XLA GPU flags.

This commit is contained in:
Justin Fu 2024-12-12 09:55:41 -08:00
parent 99d675ac25
commit 1021603f85

@ -44,11 +44,8 @@ example, we can add this to the top of a Python file:
```python
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
```
@ -58,9 +55,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
### Code generation flags
* **--xla_gpu_enable_triton_softmax_fusion** This flag enables an automatic
softmax fusion, based on pattern-matching backed by Triton code generation.
The default value is False.
* **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for
any GEMM that it supports. The default value is False.