
[JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.

Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we can do the last step: enable it in JAX.

PiperOrigin-RevId: 736799241
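
For context, a minimal usage sketch (not part of the commit) showing how the new preset can be requested from user code, assuming the existing `precision=` plumbing JAX already uses for the other DotAlgorithmPreset values; per the diff and tests below it needs XLA extension version >= 320 and, on GPU, compute capability >= 8.0.

# Hedged sketch: request the BF16_BF16_F32_X9 emulation for a float32 dot.
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((128, 128), dtype=jnp.float32)
rhs = jnp.ones((128, 128), dtype=jnp.float32)

out = lax.dot(lhs, rhs, precision=lax.DotAlgorithmPreset.BF16_BF16_F32_X9)
print(out.dtype)  # float32: the preset accumulates in f32, per its name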
Ilya Tikhonovskiy 2025-03-14 02:57:12 -07:00 committed by jax authors
parent cbece0b00b
commit 43b78c539f
2 changed files with 19 additions and 0 deletions
jax/_src/lax
tests

@@ -66,6 +66,7 @@ from jax._src.lax.utils import (
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import (PmapSharding, NamedSharding,
                                     PartitionSpec as P, canonicalize_sharding)
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
@@ -1911,6 +1912,9 @@ class DotAlgorithmPreset(enum.Enum):
  BF16_BF16_F32_X6 = enum.auto()
  """Like ``BF16_BF16_F32_X3``, but using 6 operations instead of 3."""
  BF16_BF16_F32_X9 = enum.auto()
  """Like ``BF16_BF16_F32_X3``, but using 9 operations instead of 3."""
  TF32_TF32_F32 = enum.auto()
  TF32_TF32_F32_X3 = enum.auto()
  """The ``_X3`` suffix indicates that the algorithm uses 3 operations to
@@ -2064,6 +2068,13 @@ class DotAlgorithmPreset(enum.Enum):
        return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
      case DotAlgorithmPreset.BF16_BF16_F32_X6:
        return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
      case DotAlgorithmPreset.BF16_BF16_F32_X9:
        if xla_extension_version < 320:
          raise ValueError(
              "The dot algorithm BF16_BF16_F32_X9 requires XLA extension "
              "version >= 320."
          )
        return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False)
      case DotAlgorithmPreset.TF32_TF32_F32:
        return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
      case DotAlgorithmPreset.TF32_TF32_F32_X3:
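
One way to check that the preset actually reaches the compiler is to inspect the lowered StableHLO; a hedged sketch follows (the exact textual form of the dot algorithm attribute is an assumption, so read the output rather than string-matching on it):

import jax
import jax.numpy as jnp
from jax import lax

def f(x, y):
  return lax.dot(x, y, precision=lax.DotAlgorithmPreset.BF16_BF16_F32_X9)

x = jnp.ones((8, 8), jnp.float32)
# Expect a dot_general carrying a dot algorithm with bf16 inputs, an f32
# accumulator, and num_primitive_operations = 9 (the values passed to
# hlo.DotAlgorithm.get above).
print(jax.jit(f).lower(x, x).as_text())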

@@ -49,6 +49,7 @@ from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.util import NumpyComplexWarning, safe_zip
from jax._src.tree_util import tree_map
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
@@ -1105,6 +1106,7 @@ class LaxTest(jtu.JaxTestCase):
      (lax.DotAlgorithmPreset.BF16_BF16_F32, [dtypes.bfloat16]),
      (lax.DotAlgorithmPreset.BF16_BF16_F32_X3, [np.float32]),
      (lax.DotAlgorithmPreset.BF16_BF16_F32_X6, [np.float32]),
      (lax.DotAlgorithmPreset.BF16_BF16_F32_X9, [np.float32]),
      (lax.DotAlgorithmPreset.TF32_TF32_F32, [np.float32]),
      (lax.DotAlgorithmPreset.TF32_TF32_F32_X3, [np.float32]),
      (lax.DotAlgorithmPreset.F32_F32_F32, [np.float32]),
@@ -1126,6 +1128,11 @@ class LaxTest(jtu.JaxTestCase):
      raise SkipTest(
          f"The dot algorithm '{algorithm}' is not supported on CPU.")
    if jtu.test_device_matches(["gpu"]):
      if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and
          xla_extension_version < 320):
        raise SkipTest(
            f"The dot algorithm {algorithm} requires XLA extension version "
            ">= 320.")
      # GPU algorithm support is a little spotty. It is checked in
      # xla/service/algorithm_util.cc and the logic is copied here.
      if algorithm in {
@@ -1134,6 +1141,7 @@ class LaxTest(jtu.JaxTestCase):
          lax.DotAlgorithmPreset.BF16_BF16_F32,
          lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
          lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
          lax.DotAlgorithmPreset.BF16_BF16_F32_X9,
      }:
        if not jtu.is_cuda_compute_capability_at_least("8.0"):
          raise SkipTest(