Improve based on comment #1

This commit is contained in:
shuw 2025-02-20 20:52:16 +00:00
parent ae111f7c97
commit bfb9d3ca4b
4 changed files with 197 additions and 100 deletions

View File

@@ -32,18 +32,6 @@ class BlockScaleConfig:
global_scale: Array | None
infer_only: bool
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
BlockScaleConfigs = List[BlockScaleConfig]
mxfp8_configs: BlockScaleConfigs = [mxfp8_config, mxfp8_config, mxfp8_config]
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
@@ -490,6 +478,8 @@ def quantize(x, config):
# shuw(TODO): Add when XLA is ready and e2m1fn is available.
scales_q = scales
scales_x = x
else:
raise ValueError(f"Unrecognized mode: {config.mode}.")
clipped_x = jnp.clip(scaled_x, -MAX, MAX)
x_q = clipped_x.astype(config.data_type)
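For context, the tail of `quantize` shown above saturates the scaled values to the finite range of the target dtype before casting. A minimal standalone sketch of that step, assuming the default `float8_e4m3fn` operand dtype (the values here are purely illustrative):
import jax.numpy as jnp
# Finite maximum of the target dtype; 448.0 for float8_e4m3fn.
MAX = float(jnp.finfo(jnp.float8_e4m3fn).max)
scaled_x = jnp.array([500.0, -1.5, 0.25], dtype=jnp.float32)
clipped_x = jnp.clip(scaled_x, -MAX, MAX)   # out-of-range values saturate at +/-MAX
x_q = clipped_x.astype(jnp.float8_e4m3fn)   # quantized representation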
@@ -519,38 +509,6 @@ def quantize_dequantize(x, q_dtype, scale, compute_dtype):
)
return out
def generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=jnp.float32,
):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=dtype,
)
k1, k2 = jax.random.split(jax.random.key(123), 2)
a = cast_to_representable(
jax.random.uniform(
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[1].data_type,
)
dn = ((2,), (0,))
a_3d = shape_normalization(a, dn)
b_3d = shape_normalization(b, dn)
a_q, a_scales = quantize(a, configs[0])
b_q, b_scales = quantize(b, configs[1])
return a, b, a_q, b_q, a_scales, b_scales
def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
@@ -625,7 +583,7 @@ def scaled_dot_general_transpose_lhs(
def scaled_dot_general_transpose_rhs(
g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike,
configs: BlockScaleConfigs
configs: List[BlockScaleConfig]
):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
@@ -709,12 +667,23 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers):
def scaled_dot_general_wrapper(
lhs, rhs, dimension_numbers,
preferred_element_type=jnp.float32,
configs: BlockScaleConfigs=mxfp8_configs,
configs: List[BlockScaleConfig] | None=None,
):
if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16):
msg = ('Only support preferred_element_type in (f32, bf16, f16), but got '
f'{preferred_element_type}')
raise TypeError(msg)
if configs is None:
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
configs = [mxfp8_config, mxfp8_config, mxfp8_config]
dimension_numbers = ensure_tuple(dimension_numbers)
lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim(
lhs, rhs, dimension_numbers
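For reference, a minimal sketch of how a caller can build the per-operand config list that the wrapper now constructs itself when `configs is None`. The field values mirror the mxfp8 defaults above; the shapes are illustrative, the three-entry list follows the default, and actually running it requires a GPU/cuDNN build with block-scaled matmul support:
import jax.numpy as jnp
from jax._src.cudnn.scaled_matmul_stablehlo import (
    BlockScaleConfig,
    scaled_dot_general_wrapper,
)
# Same values the wrapper uses for its mxfp8 default.
mxfp8 = BlockScaleConfig(
    mode='mxfp8',
    block_size=32,                   # one scale per 32-element block of K
    data_type=jnp.float8_e4m3fn,     # quantized operand dtype
    scale_type=jnp.float8_e8m0fnu,   # shared-exponent scale dtype
    global_scale=None,
    infer_only=False,
)
configs = [mxfp8, mxfp8, mxfp8]      # three entries, as in the default
lhs = jnp.ones((2, 128, 256), jnp.bfloat16)   # (B, M, K)
rhs = jnp.ones((2, 64, 256), jnp.bfloat16)    # (B, N, K)
dn = (((2,), (2,)), ((0,), (0,)))             # contract over K, batch over B
out = scaled_dot_general_wrapper(lhs, rhs, dn, configs=configs)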

View File

@@ -21,13 +21,17 @@ from functools import partial
import operator
import math
import numpy as np
from typing import Any, Literal
from typing import Any, List, Literal
import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import config
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper as cudnn_scaled_matmul,
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
BlockScaleConfig)
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
@@ -36,10 +40,6 @@ from jax._src.core import AxisName
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper as cudnn_scaled_matmul,
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
BlockScaleConfigs, mxfp8_configs)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
@@ -1174,10 +1174,8 @@ def scaled_matmul(
r"""
Performs scaled matrix multiplication between two 3D arrays, with scaling
factors applied to the matrices.
.. math::
\mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs)
Args:
lhs (Array): A 3D array of shape (B, M, K).
rhs (Array): A 3D array of shape (B, N, K).
@@ -1185,21 +1183,17 @@ def scaled_matmul(
rhs_scales (Array): A 3D array of shape (B, N, K_block).
preferred_element_type (DTypeLike, optional): The preferred data type
for the computation. Defaults to `jnp.float32`.
Returns:
Array: A 3D array of shape (B, M, N) representing the scaled matrix
multiplication result.
Raises:
AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
match the number of columns in `rhs` (`rhs_K`).
Notes:
- The function ensures that the `preferred_element_type` is
canonicalized before passing it to the underlying computation.
- Scaling is applied to the matrices based on the `lhs_scales` and
`rhs_scales` arrays, enabling efficient computations in blocks.
"""
B, M, lhs_K = lhs.shape
_, N, rhs_K = rhs.shape
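To make the shape contract above concrete, a small illustrative sketch. Assumptions: the function is exposed publicly as `jax.nn.scaled_matmul`, the scale blocks span 32 elements of `K` (so `K_block = K // 32`), and supporting hardware/cuDNN are available; the `quantize` helper from `jax._src.cudnn.scaled_matmul_stablehlo` is used to produce the block scales:
import jax
import jax.numpy as jnp
from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig, quantize
B, M, N, K, block = 2, 128, 64, 256, 32
config = BlockScaleConfig(
    mode='mxfp8', block_size=block,
    data_type=jnp.float8_e4m3fn, scale_type=jnp.float8_e8m0fnu,
    global_scale=None, infer_only=False,
)
k1, k2 = jax.random.split(jax.random.key(0))
lhs = jax.random.uniform(k1, (B, M, K), minval=-1.0)   # (B, M, K)
rhs = jax.random.uniform(k2, (B, N, K), minval=-1.0)   # (B, N, K)
lhs_q, lhs_scales = quantize(lhs, config)              # scales: (B, M, K // block)
rhs_q, rhs_scales = quantize(rhs, config)              # scales: (B, N, K // block)
out = jax.nn.scaled_matmul(lhs_q, rhs_q, lhs_scales, rhs_scales)  # (B, M, N)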
@@ -1222,17 +1216,15 @@ def scaled_dot_general(
lhs, rhs,
dimension_numbers,
preferred_element_type=jnp.float32,
configs: BlockScaleConfigs=mxfp8_configs,
configs: List[BlockScaleConfig] | None = None,
implementation: Literal['cudnn'] | None = None,
):
r"""Scaled dot general operation.
Computes the scaled dot general on lhs and rhs with quantization specified by configs:
.. math::
\widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\
\widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\
\mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
Args:
lhs: Left-hand side input tensor.
rhs: Right-hand side input tensor.
@@ -1241,14 +1233,37 @@ def scaled_dot_general(
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
preferred_element_type: The preferred output data type. Supported types are
`jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
configs: A `BlockScaleConfigs` objects specifying the scaling
configurations for the operation. Defaults to `mxfp8_configs`.
configs: A list of `BlockScaleConfig` specifying the scaling
configurations for the operation. Defaults to the mxfp8 configs when `None`.
implementation: A string to control which implementation backend to use.
Supported strings are `cudnn` (cuDNN block scaled dot). It defaults
to `None`, which will automatically select the best available backend.
Returns:
The result of the scaled dot general operation.
"""
return cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)
# Create configs if not provided
if configs is None:
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
configs = [mxfp8_config for _ in range(3)]
if implementation is None:
implementation = 'cudnn'
match implementation:
case 'cudnn':
out = cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
return out
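A brief usage sketch of the updated entry point. Assumptions: it is exposed as `jax.nn.scaled_dot_general`, and a cuDNN build with mxfp8 block-scaled support is present; with `configs=None` the mxfp8 defaults constructed above are used:
import jax
import jax.numpy as jnp
k1, k2 = jax.random.split(jax.random.key(0))
a = jax.random.uniform(k1, (3, 128, 64), dtype=jnp.bfloat16, minval=-1.0)
b = jax.random.uniform(k2, (3, 112, 64), dtype=jnp.bfloat16, minval=-1.0)
dn = (((2,), (2,)), ((0,), (0,)))   # contract over the last axis, batch over the first
# configs=None -> mxfp8 defaults; implementation is selected automatically when None.
out = jax.nn.scaled_dot_general(a, b, dn, implementation='cudnn')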

View File

@@ -30,9 +30,10 @@ from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.cudnn.scaled_matmul_stablehlo import (
generate_quantized_tensors,
mxfp8_configs,
quantize,
quantize_dequantize,
shape_normalization,
BlockScaleConfig,
)
from jax.test_util import check_grads
from jax import nn
@@ -42,9 +43,9 @@ import jax.numpy as jnp
config.parse_flags_with_absl()
def _is_required_cudnn_version_satisfied(min_cudnn_version):
def _is_required_cudnn_version_satisfied(min_cc, min_cudnn_version):
return (
jtu.is_cuda_compute_capability_at_least("8.0") and
jtu.is_cuda_compute_capability_at_least(min_cc) and
cuda_versions is not None and
cuda_versions.cudnn_get_version() >= min_cudnn_version
)
@@ -56,6 +57,50 @@ def _check_cudnn_backend(fn, *args, **kwargs):
_cudnn_dbias_error = 'cuDNN only supports bias gradient'
def _generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=jnp.float32,
):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=dtype,
)
k1, k2 = jax.random.split(jax.random.key(123), 2)
a = cast_to_representable(
jax.random.uniform(
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[1].data_type,
)
dn = ((2,), (0,))
a_3d = shape_normalization(a, dn)
b_3d = shape_normalization(b, dn)
a_q, a_scales = quantize(a, configs[0])
b_q, b_scales = quantize(b, configs[1])
return a, b, a_q, b_q, a_scales, b_scales
_create_mxfp8_config = lambda: BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
mxfp8_configs = [_create_mxfp8_config() for _ in range(3)]
@jtu.with_config(jax_legacy_prng_key="allow",
jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
@@ -64,10 +109,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
configs=[mxfp8_configs,],
impl=['cudnn',],
)
def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs):
def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
batch, rhs_non_contract = 4, 256
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=dtype,
)
@@ -83,10 +131,12 @@ class NNFunctionsTest(jtu.JaxTestCase):
is_training=[True, False],
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
configs=[mxfp8_configs,],
impl=['cudnn',],
)
def testScaledDotGeneral(
self, is_training, output_type, configs,
):
self, is_training, output_type, configs, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
@@ -132,7 +182,6 @@ class NNFunctionsTest(jtu.JaxTestCase):
out = j_inference(a, b)
out_ref = j_inference_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
@parameterized.product(
dtype=[jnp.bfloat16, jnp.float16],
group_num=[1, 2, 4],
@@ -140,7 +189,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@@ -189,7 +238,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
if isinstance(mask_mode, str):
mask_mode = (mask_mode,)
min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
if not _is_required_cudnn_version_satisfied(min_cudnn_version):
if not _is_required_cudnn_version_satisfied("8.0", min_cudnn_version):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
dtype = jnp.bfloat16
@@ -252,7 +301,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
use_vmap=[False, True],
)
def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
if not _is_required_cudnn_version_satisfied(8904):
if not _is_required_cudnn_version_satisfied("8.0", 8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
dtype = jnp.bfloat16

View File

@@ -13,14 +13,15 @@ from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper,
scaled_dot_general_wrapper,
mxfp8_configs,
generate_quantized_tensors,
shape_normalization,
quantize,
quantize_dequantize,
BlockScaleConfig,
)
config.parse_flags_with_absl()
input_sharding_configs = [
input_shardings = [
(("dp", None, "tp"), ("dp", None, "tp")),
(("dp", None, "tp"), ("dp", None, None)),
(("dp", None, "tp"), ("dp", "tp", None)),
@@ -41,7 +42,63 @@ expected_hlos = [
("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name),
("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name),
]
sharding_configs = [[i, j] for i, j in zip(input_sharding_configs, expected_hlos)]
expected_output_spec = [
PartitionSpec('dp',),
PartitionSpec('dp',),
PartitionSpec('dp', None, 'tp'),
PartitionSpec('dp', None, 'tp'),
PartitionSpec('dp', 'tp', None),
PartitionSpec(None, 'dp', 'tp'),
PartitionSpec(None, 'tp', None),
PartitionSpec(None, None, 'tp'),
]
sharding_configs = {
input_sharding: (hlo, output_spec)
for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
}
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
mxfp8_configs = [mxfp8_config for _ in range(3)]
def generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=jnp.float32,
):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=dtype,
)
k1, k2 = jax.random.split(jax.random.key(123), 2)
a = cast_to_representable(
jax.random.uniform(
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[1].data_type,
)
dn = ((2,), (0,))
a_3d = shape_normalization(a, dn)
b_3d = shape_normalization(b, dn)
a_q, a_scales = quantize(a, configs[0])
b_q, b_scales = quantize(b, configs[1])
return a, b, a_q, b_q, a_scales, b_scales
def shard_and_device_put(
@@ -76,7 +133,10 @@ def shard_and_device_put(
return a, b, in_shardings
def get_hlo_text(in_shardings, block_scale_configs=mxfp8_configs):
def get_hlo_text(in_shardings, block_scale_configs=None):
if block_scale_configs is None:
block_scale_configs = mxfp8_configs
mesh_names = ("dp", "tp")
devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
mesh = Mesh(devices, mesh_names)
@@ -109,15 +169,15 @@ class ScaledMatmulTest(jtu.JaxTestCase):
self.skipTest("Requires at least Blackwell arch")
@jtu.sample_product(
sharding_config=sharding_configs,
in_shardings=sharding_configs,
)
@jtu.run_on_devices("cuda")
def test_collectives(self, sharding_config):
def test_collectives(self, in_shardings):
if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
self.skipTest("Partition Test enabled for at least 4 GPUs")
input_sharding, expected_hlo = sharding_config[0], sharding_config[1]
hlo_text = get_hlo_text(input_sharding)
expected_hlo = sharding_configs[in_shardings][0]
hlo_text = get_hlo_text(in_shardings)
hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
@@ -171,22 +231,21 @@ class ScaledMatmulTest(jtu.JaxTestCase):
)
@jtu.sample_product(
sharding_config=sharding_configs,
in_shardings=sharding_configs,
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_scaled_matmul_sharded(self, sharding_config, block_scale_configs):
def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
if len(jax.local_devices()) < 4:
self.skipTest("Require at least 4 devices to run sharding tests.")
batch, contract, non_contract = 2, 1024, 256
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
batch, non_contract, contract, non_contract, block_scale_configs,
)
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
in_shardings = sharding_config[0]
expected_output_spec = sharding_configs[in_shardings][1]
with Mesh(devices, ("dp", "tp")) as mesh:
a_q, b_q, a_scales, b_scales, input_shardings = (
@@ -205,11 +264,11 @@ class ScaledMatmulTest(jtu.JaxTestCase):
j_scaled_matmul = jax.jit(
scaled_matmul_wrapper, in_shardings=input_shardings
)
hlo_text = j_scaled_matmul.lower(*args).compile().as_text()
hlo_compiled = j_scaled_matmul.lower(*args).compile()
hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
)
self.assertRegex(hlo_text, hlo_pattern)
self.assertRegex(hlo_compiled.as_text(), hlo_pattern)
j_ref = jax.jit(
partial(
@@ -221,8 +280,13 @@ class ScaledMatmulTest(jtu.JaxTestCase):
out = j_scaled_matmul(*args)
out_ref = j_ref(a, b)
expected_output_sharding = NamedSharding(
mesh=mesh, spec=expected_output_spec
)
self.assertArraysAllClose(out, out_ref, rtol=1e-3, atol=1e-3)
self.assertTrue(
out.sharding.is_equivalent_to(expected_output_sharding, out.ndim)
)
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
@@ -306,9 +370,9 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
out_ref = j_inference_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
@jtu.sample_product(sharding_config=input_sharding_configs)
@jtu.sample_product(in_shardings=sharding_configs)
@jtu.run_on_devices("cuda")
def test_dot_general_sharded(self, sharding_config):
def test_dot_general_sharded(self, in_shardings):
if len(jax.local_devices()) < 4:
self.skipTest("Require at least 4 devices to run sharding tests.")
@@ -344,7 +408,6 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
# custom scaled_matmul op.
return jnp.sum(jnp.tanh(y))
in_shardings = sharding_config
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
with Mesh(devices, ("dp", "tp")) as mesh:
@@ -375,6 +438,7 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
@jtu.sample_product(
configs=[
((1, 128, 256), (1, 128, 256), (0, 0, 0)),