Enable public doc for scaled dot

kaixih 2025-03-26 20:57:30 +00:00
parent c8ccd7570a
commit f949b8b8f6
4 changed files with 168 additions and 81 deletions

View File

@@ -54,3 +54,6 @@ Other functions
standardize
one_hot
dot_product_attention
scaled_matmul
get_scaled_dot_general_config
scaled_dot_general

View File

@@ -1210,81 +1210,184 @@ def dot_product_attention(
return jnp.reshape(out, output_shape)
def scaled_matmul(
a: Array,
b: Array,
a_scales: Array,
b_scales: Array,
preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
r"""
Performs scaled matrix multiplication between two 3D arrays, with scaling
factors applied to the matrices.
.. math::
\mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs)
r"""Scaled matrix multiplication function.
Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
The last dim is the contracting dim, and block size is inferred.
Mathematically, this operation is equivalent to::
a_block_size = a.shape[-1] // a_scales.shape[-1]
b_block_size = b.shape[-1] // b_scales.shape[-1]
a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
Args:
a (Array): A 3D array of shape (B, M, K).
b (Array): A 3D array of shape (B, N, K).
a_scales (Array): A 3D array of shape (B, M, K_a), where `K % K_a == 0`.
b_scales (Array): A 3D array of shape (B, N, K_b), where `K % K_b == 0`.
preferred_element_type (DTypeLike, optional): The preferred output data type. Defaults to `jnp.float32`.
Returns:
Array: A 3D array of shape (B, M, N) containing the scaled matrix
multiplication result.
Notes:
- We currently do not support user-defined `precision` for customizing the
compute data type. It is fixed to `jnp.float32`.
- Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`.
- To use cuDNN with Nvidia Blackwell GPUs, inputs must match::
# mxfp8
a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
a_scales, b_scales: jnp.float8_e8m0fnu
block_size: 32
# nvfp4
a, b: jnp.float4_e2m1fn
a_scales, b_scales: jnp.float8_e4m3fn
block_size: 16
Examples:
Basic case:
>>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
>>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
>>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> scaled_matmul(a, b, a_scales, b_scales)
Array([[[8.]]], dtype=float32)
Using fused cuDNN call on Blackwell GPUs:
>>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
>>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
>>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> scaled_matmul(a, b, a_scales, b_scales)
"""
assert all(x.ndim == 3 for x in (a, b, a_scales, b_scales))
B_a, M_a, K_a = a.shape
B_b, N_b, K_b = b.shape
assert K_a == K_b and B_a == B_b
B_as, M_as, K_as = a_scales.shape
B_bs, N_bs, K_bs = b_scales.shape
assert K_as == K_bs and B_as == B_bs
assert M_as == M_a and N_bs == N_b
preferred_element_type = dtypes.canonicalize_dtype(
np.dtype(preferred_element_type)
)
out = cudnn_scaled_matmul(
a,
b,
a_scales,
b_scales,
preferred_element_type=preferred_element_type,
)
return out
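For readers new to block scaling, the semantics spelled out in the docstring above can be emulated in plain `jnp`. The sketch below is illustrative only: the helper name is hypothetical, and it assumes float32 operands, so it does not model the low-precision dtypes handled by the cuDNN path.

import jax.numpy as jnp

def _scaled_matmul_reference(a, b, a_scales, b_scales):
    # Infer the per-operand block sizes from the contracting (last) dimension.
    a_block = a.shape[-1] // a_scales.shape[-1]
    b_block = b.shape[-1] // b_scales.shape[-1]
    # Broadcast each scale across its block, then contract over K.
    a_scaled = a * jnp.repeat(a_scales, a_block, axis=-1)
    b_scaled = b * jnp.repeat(b_scales, b_block, axis=-1)
    return jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)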
def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'],
global_scale: Array | None = None):
r"""Get quantization configs for scaled_dot_general.
Create quantization configs for the `jax.nn.scaled_dot_general`.
See Also:
- :func:`jax.nn.scaled_dot_general`: Scaled dot general function.
"""
if mode == 'nvfp4':
one = jnp.ones((1,), dtype=jnp.float32)
return BlockScaleConfig(
mode='nvfp4',
block_size=16,
data_type=jnp.float4_e2m1fn,
scale_type=jnp.float8_e4m3fn,
global_scale=one if global_scale is None else global_scale,
infer_only=False
)
elif mode == 'mxfp8':
return BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
else:
raise ValueError(f"Unsupported mode: {mode}")
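A typical pattern, mirroring the examples further below, is to build one config per operand (lhs, rhs, and gradients) and bind the list once; the scale value here is illustrative.

import functools
import jax.numpy as jnp
from jax import nn

# One config per operand: lhs, rhs, and the gradients.
mxfp8_configs = [nn.get_scaled_dot_general_config('mxfp8')] * 3
nvfp4_configs = [
    nn.get_scaled_dot_general_config('nvfp4', jnp.array([0.5], jnp.float32))
] * 3
# Bind the configs once and reuse the resulting function.
scaled_dot = functools.partial(nn.scaled_dot_general, configs=mxfp8_configs)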
def scaled_dot_general(
lhs, rhs,
dimension_numbers,
preferred_element_type=jnp.float32,
configs: List[BlockScaleConfig] | None = None,
implementation: Literal['cudnn'] | None = None,
):
r"""Scaled dot general operation.
Performs a generalized dot product with block-scaled quantization on the
lhs and rhs inputs. This operation extends `lax.dot_general` to support
user-defined scaling configurations.
Essentially, the operation follows::
a, a_scales = quantize(lhs, configs[0])
b, b_scales = quantize(rhs, configs[1])
c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)
Args:
lhs (ArrayLike): Input array.
rhs (ArrayLike): Input array.
dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying
the contraction and batch dimensions:
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
preferred_element_type (DTypeLike, optional): Output data type of the dot
product. Defaults to `jnp.float32`. Other valid types include
`jnp.bfloat16` and `jnp.float16`.
configs (list of BlockScaleConfig, optional): Scaling configurations for
lhs, rhs, and gradients. Users can obtain valid configurations via
`jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
are supported. If `None`, falls back to `lax.dot_general`.
implementation (Literal['cudnn'] | None, optional): Backend selector; only
`'cudnn'` (cuDNN block scaled dot) is currently available. Defaults to `None`.
Returns:
Array: The resulting tensor, with batch dimensions first, followed by
non-contracting/non-batch dimensions of lhs, and then those of rhs.
See Also:
- :func:`jax.nn.scaled_matmul`: Scaled matmul function.
- :func:`jax.lax.dot_general`: General dot product operator.
Notes:
- Unlike `nn.scaled_matmul`, which assumes quantized low-precision
inputs with explicit scaling factors, this operator takes high-precision
inputs, applies quantization internally, and handles the backward pass.
Examples:
Creating config for mxfp8:
>>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
Creating config for nvfp4:
>>> global_scale = jnp.array([0.5], jnp.float32)
>>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3
Using scaled_dot_general with the configs:
>>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
>>> lhs = random.normal(keys[0], (3, 128, 64))
>>> rhs = random.normal(keys[1], (3, 128, 64))
>>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))
"""
# Create configs if not provided
if configs is None:
@@ -1300,17 +1403,10 @@ def scaled_dot_general(
)
configs = [mxfp8_config for _ in range(3)]
out = cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)
return out
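Putting the pieces together, a call shaped like the docstring example contracts the last dimension and batches over the first. The sketch below assumes a GPU/cuDNN build that supports the mxfp8 path; key handling and shapes are illustrative.

import jax
import jax.numpy as jnp

k0, k1 = jax.random.split(jax.random.key(0))
lhs = jax.random.normal(k0, (3, 128, 64), dtype=jnp.bfloat16)
rhs = jax.random.normal(k1, (3, 128, 64), dtype=jnp.bfloat16)
configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
out = jax.nn.scaled_dot_general(
    lhs, rhs,
    (((2,), (2,)), ((0,), (0,))),  # contract over K=64, batch over B=3
    preferred_element_type=jnp.bfloat16,
    configs=configs,
)
# out has shape (3, 128, 128): batch dims first, then lhs rows, then rhs rows.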

View File

@@ -38,6 +38,7 @@ from jax._src.nn.functions import (
identity as identity,
relu6 as relu6,
dot_product_attention as dot_product_attention,
get_scaled_dot_general_config as get_scaled_dot_general_config,
scaled_dot_general as scaled_dot_general,
scaled_matmul as scaled_matmul,
selu as selu,

View File

@@ -31,7 +31,6 @@ from jax._src.lib import cuda_versions
from jax._src.cudnn.scaled_matmul_stablehlo import (
quantize,
shape_normalization,
)
from jax.test_util import check_grads
from jax import nn
@@ -110,17 +109,7 @@ def create_mxfp8_configs_if_available():
if _dtypes.float8_e8m0fnu is None:
raise unittest.SkipTest("float8_e8m0fnu is not available.")
return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)]
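The refactor above assumes the public helper reproduces the config the test previously built by hand; a quick sanity check along these lines (attribute names taken from the constructor shown earlier) would be:

import jax.numpy as jnp
from jax import nn

config = nn.get_scaled_dot_general_config("mxfp8")
assert config.mode == "mxfp8" and config.block_size == 32
assert config.data_type == jnp.float8_e4m3fn
assert config.scale_type == jnp.float8_e8m0fnu
assert config.global_scale is None and not config.infer_only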
@jtu.with_config(jax_legacy_prng_key="allow",
@@ -130,10 +119,9 @@ class NNFunctionsTest(jtu.JaxTestCase):
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
)
def testScaledMatmul(self, contract, lhs_non_contract, dtype):
if not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
# Check if float8_e8m0fnu is available
configs = create_mxfp8_configs_if_available()
@@ -153,11 +141,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
is_training=[True, False],
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
)
def testScaledDotGeneral(
self, is_training, output_type):
if not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
configs = create_mxfp8_configs_if_available()