[Docs] Remove --xla_gpu_enable_triton_softmax_fusion from docs

This flag has been a no-op for a while.

PiperOrigin-RevId: 715491248
Benjamin Chetioui 2025-01-14 12:50:22 -08:00 committed by jax authors
parent 6851700ed4
commit 57a259f447


```diff
@@ -81,7 +81,6 @@ XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
 | `xla_gpu_enable_pipelined_reduce_scatter` | Boolean (true/false) | Enable pipelining of reduce-scatter instructions. |
 | `xla_gpu_enable_pipelined_all_reduce` | Boolean (true/false) | Enable pipelining of all-reduce instructions. |
 | `xla_gpu_enable_while_loop_double_buffering` | Boolean (true/false) | Enable double-buffering for while loops. |
-| `xla_gpu_enable_triton_softmax_fusion` | Boolean (true/false) | Use Triton-based Softmax fusion. |
 | `xla_gpu_enable_all_gather_combine_by_dim` | Boolean (true/false) | Combine all-gather ops with the same gather dimension or irrespective of their dimension. |
 | `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension. |
```
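For context, the documented pattern shown in the hunk header passes XLA flags through the `XLA_FLAGS` environment variable. A minimal sketch of doing the same from Python, using one of the flags that remains in the table (the flag value and the idea of setting it before importing jax follow the docs pattern; the `import jax` step is commented out and assumed, since flags are read when the library initializes):

```python
import os

# Sketch: XLA flags are passed via the XLA_FLAGS environment variable.
# This must be set before importing jax, because XLA reads the flags
# at library initialization time.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_while_loop_double_buffering=true"

# import jax  # assumed available; the flag takes effect on import

print(os.environ["XLA_FLAGS"])
```

Equivalently, from the shell: `XLA_FLAGS='--xla_gpu_enable_while_loop_double_buffering=true' python3 source.py`, matching the docs line above.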