Conditionally create mxfp8_configs.
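This replaces the module-level mxfp8_configs lists in the two changed test files with a create_mxfp8_configs_if_available() helper that skips the calling test when the float8_e8m0fnu scale dtype is unavailable. Below is a minimal, self-contained sketch of that helper; the imports are assembled here from the context lines of the diff further down, and in the real test modules the function sits alongside the other test utilities:

    import unittest

    import jax.numpy as jnp
    from jax._src import dtypes as _dtypes
    from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig

    def create_mxfp8_configs_if_available():
      # The scale dtype is exposed as None when float8_e8m0fnu is not
      # available, so skip the calling test instead of building configs.
      if _dtypes.float8_e8m0fnu is None:
        raise unittest.SkipTest("float8_e8m0fnu is not available.")

      def _create_mxfp8_config():
        # One MXFP8 block-scaling config: fp8 e4m3 data with e8m0 scales
        # over blocks of 32 elements.
        return BlockScaleConfig(
            mode='mxfp8',
            block_size=32,
            data_type=jnp.float8_e4m3fn,
            scale_type=jnp.float8_e8m0fnu,
            global_scale=None,
            infer_only=False,
        )

      # Three identical configs, matching the previous module-level list.
      return [_create_mxfp8_config() for _ in range(3)]

Because the configs are now built on demand (inside the test body or at parameterization time) rather than in a module-level list, environments without float8_e8m0fnu get a skip rather than an error when the configs are constructed.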

author Shu Wang 2025-02-21 23:08:22 -06:00 (committed by GitHub)
parent bfb9d3ca4b
commit 08012e9c01
2 changed files with 42 additions and 26 deletions

Changed file 1 of 2

@@ -27,6 +27,7 @@ import scipy.stats
 from jax._src import ad_checkpoint
 from jax._src import config
 from jax._src import core
+from jax._src import dtypes as _dtypes
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
 from jax._src.cudnn.scaled_matmul_stablehlo import (
@@ -90,16 +91,22 @@ def _generate_quantized_tensors(
   return a, b, a_q, b_q, a_scales, b_scales
-_create_mxfp8_config = lambda: BlockScaleConfig(
-    mode='mxfp8',
-    block_size=32,
-    data_type=jnp.float8_e4m3fn,
-    scale_type=jnp.float8_e8m0fnu,
-    global_scale=None,
-    infer_only=False
-)
+def create_mxfp8_configs_if_available():
+  if _dtypes.float8_e8m0fnu is None:
+    raise unittest.SkipTest("float8_e8m0fnu is not available.")
+  def _create_mxfp8_config():
+    return BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+  return [_create_mxfp8_config() for _ in range(3)]
-mxfp8_configs = [_create_mxfp8_config() for _ in range(3)]
 @jtu.with_config(jax_legacy_prng_key="allow",
                  jax_numpy_dtype_promotion="standard")
@@ -108,12 +115,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
       contract=[160, 96],
       lhs_non_contract=[240, 100],
       dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
-      configs=[mxfp8_configs,],
       impl=['cudnn',],
   )
-  def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs, impl):
+  def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
+    # Check if float8_e8m0fnu is available
+    configs = create_mxfp8_configs_if_available()
     batch, rhs_non_contract = 4, 256
     a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors(
         batch, lhs_non_contract, contract, rhs_non_contract,
@@ -130,13 +138,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
   @parameterized.product(
       is_training=[True, False],
       output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
-      configs=[mxfp8_configs,],
       impl=['cudnn',],
   )
   def testScaledDotGeneral(
-      self, is_training, output_type, configs, impl):
+      self, is_training, output_type, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
+    configs = create_mxfp8_configs_if_available()
     cast_to_representable = partial(
         quantize_dequantize,
         scale=jnp.ones((1,)),
@@ -182,6 +191,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     out = j_inference(a, b)
     out_ref = j_inference_ref(a, b)
     self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
+
   @parameterized.product(
       dtype=[jnp.bfloat16, jnp.float16],
       group_num=[1, 2, 4],

Changed file 2 of 2

@@ -57,15 +57,21 @@ sharding_configs = {
     for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
 }
-mxfp8_config = BlockScaleConfig(
-    mode='mxfp8',
-    block_size=32,
-    data_type=jnp.float8_e4m3fn,
-    scale_type=jnp.float8_e8m0fnu,
-    global_scale=None,
-    infer_only=False
-)
-mxfp8_configs = [mxfp8_config for _ in range(3)]
+def create_mxfp8_configs_if_available():
+  if _dtypes.float8_e8m0fnu is None:
+    raise unittest.SkipTest("float8_e8m0fnu is not available.")
+  def _create_mxfp8_config():
+    return BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+  return [_create_mxfp8_config() for _ in range(3)]
 def generate_quantized_tensors(
     batch, lhs_non_contract, contract, rhs_non_contract,
@@ -135,7 +141,7 @@ def shard_and_device_put(
 def get_hlo_text(in_shardings, block_scale_configs=None):
   if block_scale_configs is None:
-    block_scale_configs = mxfp8_configs
+    block_scale_configs = create_mxfp8_configs_if_available()
   mesh_names = ("dp", "tp")
   devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
@@ -190,7 +196,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
       contract=[160, 96],
       lhs_non_contract=[240, 100],
       dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
-      block_scale_configs=[mxfp8_configs,],
+      block_scale_configs=[create_mxfp8_configs_if_available(),],
   )
   @jtu.run_on_devices("cuda")
   def test_scaled_matmul(
@@ -232,7 +238,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
   @jtu.sample_product(
       in_shardings=sharding_configs,
-      block_scale_configs=[mxfp8_configs,],
+      block_scale_configs=[create_mxfp8_configs_if_available(),],
   )
   @jtu.run_on_devices("cuda")
   def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
@@ -291,7 +297,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
-  block_scale_configs = mxfp8_configs
+  block_scale_configs = create_mxfp8_configs_if_available()
   def setUp(self):
     super().setUp()