[JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
Two PRs that support this feature have been submitted to stablehlo and openxla. Now we can do the last step: enable it in JAX.

PiperOrigin-RevId: 736799241
parent cbece0b00b
commit 43b78c539f
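A minimal usage sketch, not part of this commit: it assumes a jaxlib whose XLA extension version is >= 320 and, on GPU, a device with compute capability >= 8.0 (matching the guards added below), and that the preset is requested through the precision argument like the other dot-algorithm presets.

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def dot_x9(a, b):
  # Hypothetical example: ask XLA to emulate the f32 dot with nine bf16
  # operations while accumulating in f32.
  return jnp.dot(a, b, precision=jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9)

a = np.random.rand(256, 256).astype(np.float32)
b = np.random.rand(256, 256).astype(np.float32)
print(dot_x9(a, b).dtype)  # float32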
@@ -66,6 +66,7 @@ from jax._src.lax.utils import (
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
+from jax._src.lib import xla_extension_version
 from jax._src.sharding_impls import (PmapSharding, NamedSharding,
                                      PartitionSpec as P, canonicalize_sharding)
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
@@ -1911,6 +1912,9 @@ class DotAlgorithmPreset(enum.Enum):
   BF16_BF16_F32_X6 = enum.auto()
   """Like ``BF16_BF16_F32_X3``, but using 6 operations instead of 3."""

+  BF16_BF16_F32_X9 = enum.auto()
+  """Like ``BF16_BF16_F32_X3``, but using 9 operations instead of 3."""
+
   TF32_TF32_F32 = enum.auto()
   TF32_TF32_F32_X3 = enum.auto()
   """The ``_X3`` suffix indicates that the algorithm uses 3 operations to
@@ -2064,6 +2068,13 @@ class DotAlgorithmPreset(enum.Enum):
         return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False)
       case DotAlgorithmPreset.BF16_BF16_F32_X6:
         return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
+      case DotAlgorithmPreset.BF16_BF16_F32_X9:
+        if xla_extension_version < 320:
+          raise ValueError(
+              "The dot algorithm BF16_BF16_F32_X9 requires XLA extension "
+              "version >= 320."
+          )
+        return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False)
       case DotAlgorithmPreset.TF32_TF32_F32:
         return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
       case DotAlgorithmPreset.TF32_TF32_F32_X3:
@@ -49,6 +49,7 @@ from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
 from jax._src.util import NumpyComplexWarning, safe_zip
 from jax._src.tree_util import tree_map
+from jax._src.lib import xla_extension_version

 config.parse_flags_with_absl()

@@ -1105,6 +1106,7 @@ class LaxTest(jtu.JaxTestCase):
       (lax.DotAlgorithmPreset.BF16_BF16_F32, [dtypes.bfloat16]),
       (lax.DotAlgorithmPreset.BF16_BF16_F32_X3, [np.float32]),
       (lax.DotAlgorithmPreset.BF16_BF16_F32_X6, [np.float32]),
+      (lax.DotAlgorithmPreset.BF16_BF16_F32_X9, [np.float32]),
       (lax.DotAlgorithmPreset.TF32_TF32_F32, [np.float32]),
       (lax.DotAlgorithmPreset.TF32_TF32_F32_X3, [np.float32]),
       (lax.DotAlgorithmPreset.F32_F32_F32, [np.float32]),
@@ -1126,6 +1128,11 @@ class LaxTest(jtu.JaxTestCase):
         raise SkipTest(
             f"The dot algorithm '{algorithm}' is not supported on CPU.")
     if jtu.test_device_matches(["gpu"]):
+      if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and
+          xla_extension_version < 320):
+        raise SkipTest(
+            f"The dot algorithm {algorithm} requires XLA extension version "
+            ">= 320.")
       # GPU algorithm support is a little spotty. It is checked in
       # xla/service/algorithm_util.cc and the logic is copied here.
       if algorithm in {
@@ -1134,6 +1141,7 @@ class LaxTest(jtu.JaxTestCase):
           lax.DotAlgorithmPreset.BF16_BF16_F32,
           lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
           lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
+          lax.DotAlgorithmPreset.BF16_BF16_F32_X9,
       }:
         if not jtu.is_cuda_compute_capability_at_least("8.0"):
           raise SkipTest(
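To confirm the preset reaches the compiler, a quick sketch (again not from this change) is to lower a jitted dot and inspect the StableHLO: the dot_general should carry an algorithm attribute with nine primitive operations, mirroring the hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False) call in the lowering above.

import jax
import jax.numpy as jnp
import numpy as np

x = np.ones((8, 8), np.float32)
lowered = jax.jit(
    lambda a, b: jnp.dot(
        a, b, precision=jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9)
).lower(x, x)
# Look for the dot_general's algorithm attribute (bf16 inputs, f32
# accumulation, num_primitive_operations = 9) in the printed module.
print(lowered.as_text())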