mirror of https://github.com/ROCm/jax.git
Conditionally create mxfp8_configs.
parent bfb9d3ca4b
commit 08012e9c01
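
Summary of the change, with a condensed sketch of the new helper introduced by the diff below. The import of BlockScaleConfig from jax._src.cudnn.scaled_matmul_stablehlo is an assumption based on the truncated import statement in the diff; everything else mirrors the added lines.

import unittest

import jax.numpy as jnp
from jax._src import dtypes as _dtypes
from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig  # assumed import location


def create_mxfp8_configs_if_available():
  # Skip the calling test when the e8m0 scale dtype is not available in this
  # build (the removed module-level configs assumed it always exists).
  if _dtypes.float8_e8m0fnu is None:
    raise unittest.SkipTest("float8_e8m0fnu is not available.")

  def _create_mxfp8_config():
    return BlockScaleConfig(
        mode='mxfp8',
        block_size=32,
        data_type=jnp.float8_e4m3fn,
        scale_type=jnp.float8_e8m0fnu,
        global_scale=None,
        infer_only=False
    )

  return [_create_mxfp8_config() for _ in range(3)]

Call sites now invoke this helper (inside the test body or where a config list is expected) instead of referencing a module-level mxfp8_configs list that was built unconditionally at import time.
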
@@ -27,6 +27,7 @@ import scipy.stats
 from jax._src import ad_checkpoint
 from jax._src import config
 from jax._src import core
+from jax._src import dtypes as _dtypes
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
 from jax._src.cudnn.scaled_matmul_stablehlo import (
@@ -90,16 +91,22 @@ def _generate_quantized_tensors(
 
   return a, b, a_q, b_q, a_scales, b_scales
 
-_create_mxfp8_config = lambda: BlockScaleConfig(
-    mode='mxfp8',
-    block_size=32,
-    data_type=jnp.float8_e4m3fn,
-    scale_type=jnp.float8_e8m0fnu,
-    global_scale=None,
-    infer_only=False
-)
+def create_mxfp8_configs_if_available():
+  if _dtypes.float8_e8m0fnu is None:
+    raise unittest.SkipTest("float8_e8m0fnu is not available.")
+
+  def _create_mxfp8_config():
+    return BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+
+  return [_create_mxfp8_config() for _ in range(3)]
 
-mxfp8_configs = [_create_mxfp8_config() for _ in range(3)]
 
 @jtu.with_config(jax_legacy_prng_key="allow",
                  jax_numpy_dtype_promotion="standard")
@@ -108,12 +115,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
       contract=[160, 96],
       lhs_non_contract=[240, 100],
       dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
-      configs=[mxfp8_configs,],
       impl=['cudnn',],
   )
-  def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs, impl):
+  def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
+    # Check if float8_e8m0fnu is available
+    configs = create_mxfp8_configs_if_available()
     batch, rhs_non_contract = 4, 256
     a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors(
         batch, lhs_non_contract, contract, rhs_non_contract,
@@ -130,13 +138,14 @@ class NNFunctionsTest(jtu.JaxTestCase):
   @parameterized.product(
       is_training=[True, False],
       output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
-      configs=[mxfp8_configs,],
       impl=['cudnn',],
   )
   def testScaledDotGeneral(
-      self, is_training, output_type, configs, impl):
+      self, is_training, output_type, impl):
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
+
+    configs = create_mxfp8_configs_if_available()
     cast_to_representable = partial(
         quantize_dequantize,
         scale=jnp.ones((1,)),
@@ -182,6 +191,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     out = j_inference(a, b)
     out_ref = j_inference_ref(a, b)
     self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
+
   @parameterized.product(
       dtype=[jnp.bfloat16, jnp.float16],
       group_num=[1, 2, 4],
@@ -57,15 +57,21 @@ sharding_configs = {
     for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
 }
 
-mxfp8_config = BlockScaleConfig(
-    mode='mxfp8',
-    block_size=32,
-    data_type=jnp.float8_e4m3fn,
-    scale_type=jnp.float8_e8m0fnu,
-    global_scale=None,
-    infer_only=False
-)
-mxfp8_configs = [mxfp8_config for _ in range(3)]
+def create_mxfp8_configs_if_available():
+  if _dtypes.float8_e8m0fnu is None:
+    raise unittest.SkipTest("float8_e8m0fnu is not available.")
+
+  def _create_mxfp8_config():
+    return BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+
+  return [_create_mxfp8_config() for _ in range(3)]
 
 def generate_quantized_tensors(
     batch, lhs_non_contract, contract, rhs_non_contract,
@@ -135,7 +141,7 @@ def shard_and_device_put(
 
 def get_hlo_text(in_shardings, block_scale_configs=None):
   if block_scale_configs is None:
-    block_scale_configs = mxfp8_configs
+    block_scale_configs = create_mxfp8_configs_if_available()
 
   mesh_names = ("dp", "tp")
   devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
@@ -190,7 +196,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
       contract=[160, 96],
       lhs_non_contract=[240, 100],
       dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
-      block_scale_configs=[mxfp8_configs,],
+      block_scale_configs=[create_mxfp8_configs_if_available(),],
   )
   @jtu.run_on_devices("cuda")
   def test_scaled_matmul(
@@ -232,7 +238,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
 
   @jtu.sample_product(
       in_shardings=sharding_configs,
-      block_scale_configs=[mxfp8_configs,],
+      block_scale_configs=[create_mxfp8_configs_if_available(),],
   )
   @jtu.run_on_devices("cuda")
   def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
@@ -291,7 +297,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
 
-  block_scale_configs = mxfp8_configs
+  block_scale_configs = create_mxfp8_configs_if_available()
 
   def setUp(self):
     super().setUp()
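
For context, a minimal self-contained sketch (not part of the commit) of the skip mechanism the new helper relies on: raising unittest.SkipTest from a helper called inside a test method marks that test as skipped rather than failed. The names below (maybe_make_configs, FEATURE_AVAILABLE) are hypothetical placeholders standing in for create_mxfp8_configs_if_available and the _dtypes.float8_e8m0fnu check.

import unittest

FEATURE_AVAILABLE = False  # hypothetical stand-in for `_dtypes.float8_e8m0fnu is not None`

def maybe_make_configs():
  # Same pattern as the commit: bail out with SkipTest when the required
  # dtype/feature is missing, otherwise build the config list.
  if not FEATURE_AVAILABLE:
    raise unittest.SkipTest("feature is not available")
  return ["config"] * 3

class ExampleTest(unittest.TestCase):
  def test_uses_configs(self):
    configs = maybe_make_configs()  # reported as skipped when unavailable
    self.assertEqual(len(configs), 3)

if __name__ == "__main__":
  unittest.main()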