shuw 2025-02-25 17:11:56 +00:00
parent d9456f36b3
commit 681ee18436
3 changed files with 33 additions and 31 deletions

View File

@@ -28,10 +28,6 @@ import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import config
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper as cudnn_scaled_matmul,
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
BlockScaleConfig)
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
@@ -40,6 +36,10 @@ from jax._src.core import AxisName
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper as cudnn_scaled_matmul,
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
BlockScaleConfig)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util

View File

@@ -104,7 +104,7 @@ def create_mxfp8_configs_if_available():
global_scale=None,
infer_only=False
)
return [_create_mxfp8_config() for _ in range(3)]
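For context: the mxfp8 BlockScaleConfig this helper builds (its full field list appears in the test file below: block_size=32, float8_e4m3fn data, float8_e8m0fnu scales) describes block-scaled quantization in which every 32 contiguous values share one power-of-two scale. A minimal illustrative sketch of that scheme, not library code; the helper name and the headroom choice are assumptions:

import jax.numpy as jnp

# Illustrative only: mxfp8 stores data as float8_e4m3fn plus one shared
# power-of-two (e8m0) scale per block of 32 values.
def quantize_block_mxfp8(x):
    # x: a 1-D block of 32 values; assumes max(|x|) > 0 for simplicity.
    amax = jnp.max(jnp.abs(x))
    # Pick a power-of-two scale so the scaled block fits the e4m3 range
    # (the largest finite float8_e4m3fn value is 448).
    scale = jnp.exp2(jnp.ceil(jnp.log2(amax / 448.0)))
    data = (x / scale).astype(jnp.float8_e4m3fn)
    return data, scale  # dequantize: data.astype(jnp.float32) * scale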

View File

@@ -8,6 +8,7 @@ import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax._src import config
from jax._src import dtypes as _dtypes
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.scaled_matmul_stablehlo import (
@@ -19,7 +20,6 @@ from jax._src.cudnn.scaled_matmul_stablehlo import (
BlockScaleConfig,
)
config.parse_flags_with_absl()
input_shardings = [
(("dp", None, "tp"), ("dp", None, "tp")),
@@ -57,22 +57,6 @@ sharding_configs = {
for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
}
def create_mxfp8_configs_if_available():
if _dtypes.float8_e8m0fnu is None:
raise unittest.SkipTest("float8_e8m0fnu is not available.")
def _create_mxfp8_config():
return BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
return [_create_mxfp8_config() for _ in range(3)]
def generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
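For orientation: sharding_configs (built by the zip comprehension above) maps each pair of input shardings to its (expected_hlo, expected_output_spec) pair, which is why test_collectives below reads sharding_configs[in_shardings][0]. A toy rebuild of that comprehension; the HLO snippets and output spec here are made up for illustration:

# Toy version of the comprehension above; values are illustrative.
input_shardings = [(("dp", None, "tp"), ("dp", None, "tp"))]
expected_hlos = [("all-gather", "dot")]
expected_output_spec = [("dp", None, None)]
sharding_configs = {
    input_sharding: (hlo, output_spec)
    for input_sharding, hlo, output_spec
    in zip(input_shardings, expected_hlos, expected_output_spec)
}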
@@ -139,11 +123,22 @@ def shard_and_device_put(
return a, b, in_shardings
def create_mxfp8_configs():
if _dtypes.float8_e8m0fnu is None:
return None
def get_hlo_text(in_shardings, block_scale_configs=None):
if block_scale_configs is None:
block_scale_configs = create_mxfp8_configs_if_available()
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
return [mxfp8_config for _ in range(3)]
def get_hlo_text(in_shardings, block_scale_configs):
mesh_names = ("dp", "tp")
devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
mesh = Mesh(devices, mesh_names)
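The mesh built here is the 2x2 ("dp", "tp") layout the tests assume throughout. For reference, an input-sharding tuple such as ("dp", None, "tp") is consumed as PartitionSpec axes over that mesh; a minimal sketch, where the operand shape is an illustrative assumption and at least 4 devices are required:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Sketch: turning an in_shardings entry into a sharded array on the
# 2x2 ("dp", "tp") mesh. Shape (8, 16, 32) is illustrative.
devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
mesh = Mesh(devices, ("dp", "tp"))
sharding = NamedSharding(mesh, P("dp", None, "tp"))
x = jax.device_put(jnp.ones((8, 16, 32)), sharding)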
@@ -170,21 +165,26 @@ class ScaledMatmulTest(jtu.JaxTestCase):
except RuntimeError as e:
self.skipTest(str(e))
return
if _dtypes.float8_e8m0fnu is None:
self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
if cudnn_version < 90700:
self.skipTest("Requires >= cuDNN 9.7.0")
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Requires at least Blackwell arch")
mxfp8_configs = create_mxfp8_configs()
@jtu.sample_product(
in_shardings=sharding_configs,
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_collectives(self, in_shardings):
def test_collectives(self, in_shardings, block_scale_configs):
if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
self.skipTest("Partition Test enabled for at least 4 GPUs")
expected_hlo = sharding_configs[in_shardings][0]
hlo_text = get_hlo_text(in_shardings)
hlo_text = get_hlo_text(in_shardings, block_scale_configs)
hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
@@ -197,7 +197,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
block_scale_configs=[create_mxfp8_configs_if_available(),],
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_scaled_matmul(
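jtu.sample_product parametrizes the test over the cross product of the keyword lists, so the lists shown above alone yield 2 x 2 x 3 = 12 combinations (times the single block-scale-configs entry). Roughly equivalent, as a sketch:

import itertools
import jax.numpy as jnp

# Sketch: the decorator expands keyword lists into one case per combination.
cases = list(itertools.product(
    [160, 96],                                 # contract
    [240, 100],                                # lhs_non_contract
    [jnp.float16, jnp.bfloat16, jnp.float32],  # dtype
))
assert len(cases) == 12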
@@ -239,7 +239,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
@jtu.sample_product(
in_shardings=sharding_configs,
block_scale_configs=[create_mxfp8_configs_if_available(),],
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
@@ -298,8 +298,6 @@ class ScaledMatmulTest(jtu.JaxTestCase):
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
block_scale_configs = create_mxfp8_configs_if_available()
def setUp(self):
super().setUp()
try:
@@ -307,11 +305,15 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
except RuntimeError as e:
self.skipTest(str(e))
return
if _dtypes.float8_e8m0fnu is None:
self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
if cudnn_version < 90700:
self.skipTest("Requires >= cuDNN 9.7.0")
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Requires at least Blackwell arch")
block_scale_configs = create_mxfp8_configs()
@jtu.sample_product(
configs=[
# a_shape, b_shape, dimension_numbers, is_training
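Taken together, the pattern this commit converges on: the config factory returns None instead of raising unittest.SkipTest when ml_dtypes lacks float8_e8m0fnu, and each setUp turns that condition into a skip, so merely importing the module can no longer fail. A condensed sketch; the test class is an illustrative stand-in, while the factory name and field values follow the diff:

import unittest
import jax.numpy as jnp
from jax._src import dtypes as _dtypes
from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig

# Condensed from the diff: return None rather than raising, and let
# setUp decide whether to skip.
def create_mxfp8_configs():
    if _dtypes.float8_e8m0fnu is None:
        return None
    config = BlockScaleConfig(
        mode='mxfp8', block_size=32,
        data_type=jnp.float8_e4m3fn, scale_type=jnp.float8_e8m0fnu,
        global_scale=None, infer_only=False,
    )
    return [config for _ in range(3)]

class ExampleTest(unittest.TestCase):  # illustrative stand-in class
    def setUp(self):
        super().setUp()
        if _dtypes.float8_e8m0fnu is None:
            self.skipTest("Requires ml_dtypes >= 0.5.0 for float8_e8m0fnu")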