Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Enable public doc for scaled dot
parent c8ccd7570a
commit f949b8b8f6
@@ -54,3 +54,6 @@ Other functions
standardize
one_hot
dot_product_attention
scaled_matmul
get_scaled_dot_general_config
scaled_dot_general
@@ -1210,81 +1210,184 @@ def dot_product_attention(
return jnp.reshape(out, output_shape)

def scaled_matmul(
lhs: Array,
rhs: Array,
lhs_scales: Array,
rhs_scales: Array,
a: Array,
b: Array,
a_scales: Array,
b_scales: Array,
preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
r"""
Performs scaled matrix multiplication between two 3D arrays, with scaling
factors applied to the matrices.
.. math::
\mathrm{ScaledMatmul}(lhs, rhs, lhs\_scales, rhs\_scales) = lhs\_scales \cdot rhs\_scales \cdot \mathrm{dot}(lhs, rhs)
r"""Scaled matrix multiplication function.

Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
The last dim is the contracting dim, and block size is inferred.

Mathematically, this operation is equivalent to::

a_block_size = a.shape[-1] // a_scales.shape[-1]
b_block_size = b.shape[-1] // b_scales.shape[-1]
a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)

Args:
lhs (Array): A 3D array of shape (B, M, K).
rhs (Array): A 3D array of shape (B, N, K).
lhs_scales (Array): A 3D array of shape (B, M, K_block).
rhs_scales (Array): A 3D array of shape (B, N, K_block).
preferred_element_type (DTypeLike, optional): The preferred data type
for the computation. Defaults to `jnp.float32`.
a (Array): Shape (B, M, K).
b (Array): Shape (B, N, K).
a_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
b_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.

Returns:
Array: A 3D array of shape (B, M, N) representing the scaled matrix
multiplication result.
Raises:
AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
match the number of columns in `rhs` (`rhs_K`).
Array of shape (B, M, N).

Notes:
- The function ensures that the `preferred_element_type` is
canonicalized before passing it to the underlying computation.
- Scaling is applied to the matrices based on the `lhs_scales` and
`rhs_scales` arrays, enabling efficient computations in blocks.
- We currently do not support user-defined `precision` for customizing the
compute data type. It is fixed to `jnp.float32`.
- Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`.
- To use cuDNN with Nvidia Blackwell GPUs, inputs must match::

# mxfp8
a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
a_scales, b_scales: jnp.float8_e8m0fnu
block_size: 32
# nvfp4
a, b: jnp.float4_e2m1fn
a_scales, b_scales: jnp.float8_e4m3fn
block_size: 16

Examples:

Basic case:

>>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
>>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
>>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
>>> scaled_matmul(a, b, a_scales, b_scales)
Array([[[8.]]], dtype=float32)

Using fused cuDNN call on Blackwell GPUs:

>>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
>>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
>>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
>>> scaled_matmul(a, b, a_scales, b_scales)
"""
B, M, lhs_K = lhs.shape
_, N, rhs_K = rhs.shape
assert lhs_K == rhs_K
_, _, K_block = lhs_scales.shape
assert all(x.ndim == 3 for x in (a, b, a_scales, b_scales))
B_a, M_a, K_a = a.shape
B_b, N_b, K_b = b.shape
assert K_a == K_b and B_a == B_b
B_as, M_as, K_as = a_scales.shape
B_bs, N_bs, K_bs = b_scales.shape
assert K_as == K_bs and B_as == B_bs
assert M_as == M_a and N_bs == N_b

preferred_element_type = dtypes.canonicalize_dtype(
np.dtype(preferred_element_type)
)
out = cudnn_scaled_matmul(
lhs,
rhs,
lhs_scales,
rhs_scales,
a,
b,
a_scales,
b_scales,
preferred_element_type=preferred_element_type,
)
return out
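
The block-scaling semantics documented above can be reproduced with plain `jnp` ops. The following is an editor's minimal reference sketch of that equivalence, not code from this commit; the helper name `reference_scaled_matmul` is illustrative only::

    import jax.numpy as jnp

    def reference_scaled_matmul(a, b, a_scales, b_scales):
        # Block size along the contracting (last) dim is inferred from the scale shapes.
        a_block = a.shape[-1] // a_scales.shape[-1]
        b_block = b.shape[-1] // b_scales.shape[-1]
        # Broadcast each per-block scale across its block and apply it elementwise.
        a_scaled = a.astype(jnp.float32) * jnp.repeat(a_scales.astype(jnp.float32), a_block, axis=-1)
        b_scaled = b.astype(jnp.float32) * jnp.repeat(b_scales.astype(jnp.float32), b_block, axis=-1)
        # Batched contraction over the last dim: (B, M, K) x (B, N, K) -> (B, M, N).
        return jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)

    # The docstring's basic case: one block spanning all of K, both scales 0.5.
    a = jnp.array([1., 2., 3.]).reshape((1, 1, 3))
    b = jnp.array([4., 5., 6.]).reshape((1, 1, 3))
    a_scales = jnp.array([0.5]).reshape((1, 1, 1))
    b_scales = jnp.array([0.5]).reshape((1, 1, 1))
    print(reference_scaled_matmul(a, b, a_scales, b_scales))  # [[[8.]]] = 0.25 * (4 + 10 + 18)

The cuDNN path computes the same result but additionally requires the narrow-precision dtypes and block sizes listed in the Notes above.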

def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'],
global_scale: Array | None = None):
r"""Get quantization configs for scaled_dot_general.

Create quantization configs for `jax.nn.scaled_dot_general`.

See Also:
- :func:`jax.nn.scaled_dot_general`: Scaled dot general function.
"""

if mode == 'nvfp4':
one = jnp.ones((1,), dtype=jnp.float32)
return BlockScaleConfig(
mode='nvfp4',
block_size=16,
data_type=jnp.float4_e2m1fn,
scale_type=jnp.float8_e4m3fn,
global_scale=one if global_scale is None else global_scale,
infer_only=False
)
elif mode == 'mxfp8':
return BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
else:
raise ValueError(f"Unsupported mode: {mode}")
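
The two modes differ only in the fields baked into the returned `BlockScaleConfig`. Below is a hedged sketch of how a caller might build and compare both configs; it assumes `BlockScaleConfig` exposes its fields as plain attributes, which the constructor calls above suggest but this diff does not confirm::

    import jax
    import jax.numpy as jnp

    # mxfp8: 32-wide blocks of float8_e4m3fn data scaled by float8_e8m0fnu, no global scale.
    mxfp8 = jax.nn.get_scaled_dot_general_config('mxfp8')
    # nvfp4: 16-wide blocks of float4_e2m1fn data scaled by float8_e4m3fn, plus a global scale.
    nvfp4 = jax.nn.get_scaled_dot_general_config(
        'nvfp4', global_scale=jnp.array([0.5], jnp.float32))

    # Attribute access here is assumed; adjust to however BlockScaleConfig is defined.
    for cfg in (mxfp8, nvfp4):
        print(cfg.mode, cfg.block_size, cfg.data_type, cfg.scale_type)

Passing three such configs, one each for lhs, rhs, and the gradient, is exactly what the `scaled_dot_general` examples below do.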

def scaled_dot_general(
lhs, rhs,
dimension_numbers,
preferred_element_type=jnp.float32,
configs: List[BlockScaleConfig] | None = None,
implementation: Literal['cudnn'] | None = None,
):
r"""Scaled dot general operation.
Computes the scaled dot general on lhs, rhs with quantization specified by configs:
.. math::
\widehat{lhs}, s_a = \mathrm{quantize}(lhs) \\
\widehat{rhs}, s_b = \mathrm{quantize}(rhs) \\
\mathrm{ScaledDot}(lhs, rhs) = s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs})

Performs a generalized dot product with block-scaled quantization on the
lhs and rhs inputs. This operation extends `lax.dot_general` to support
user-defined scaling configurations.

Essentially, the operation follows::

a, a_scales = quantize(lhs, configs[0])
b, b_scales = quantize(rhs, configs[1])
c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)

Args:
lhs: Left-hand side input tensor.
rhs: Right-hand side input tensor.
dimension_numbers: A tuple specifying the contraction and batch dimensions
for the dot general operation. Must follow the format:
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
preferred_element_type: The preferred output data type. Supported types are
`jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
configs: A list of `BlockScaleConfig` specifying the scaling
configurations for the operation. Defaults to `mxfp8`.
implementation: A string to control which implementation backend to use.
Supported strings are `cudnn` (cuDNN block scaled dot). It defaults
to `None`, which will automatically select the best available backend.
lhs (ArrayLike): Input array.
rhs (ArrayLike): Input array.
dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying
the contraction and batch dimensions:
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
preferred_element_type (DTypeLike, optional): Output data type of the dot
product. Defaults to `jnp.float32`. Other valid types include
`jnp.bfloat16` and `jnp.float16`.
configs (list of BlockScaleConfig, optional): Scaling configurations for
lhs, rhs, and gradients. Users can obtain valid configurations via
`jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
are supported. If `None`, falls back to `lax.dot_general`.

Returns:
The result of the scaled dot general operation.
Array: The resulting tensor, with batch dimensions first, followed by
non-contracting/non-batch dimensions of lhs, and then those of rhs.

See Also:
- :func:`jax.nn.scaled_matmul`: Scaled matmul function.
- :func:`jax.lax.dot_general`: General dot product operator.

Notes:
- Unlike `nn.scaled_matmul`, which assumes quantized low-precision
inputs with explicit scaling factors, this operator takes high-precision
inputs, applies quantization internally, and handles the backward pass.

Examples:

Creating config for mxfp8:

>>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3

Creating config for nvfp4:

>>> global_scale = jnp.array([0.5], jnp.float32)
>>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3

Using scaled_dot_general with the configs:

>>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
>>> lhs = random.normal(keys[0], (3, 128, 64))
>>> rhs = random.normal(keys[1], (3, 128, 64))
>>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))
"""
# Create configs if not provided
if configs is None:
@@ -1300,17 +1403,10 @@ def scaled_dot_general(
)
configs = [mxfp8_config for _ in range(3)]

if implementation is None:
implementation = 'cudnn'

match implementation:
case 'cudnn':
out = cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
out = cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)

return out
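
Putting the pieces of the docstring together, an end-to-end call might look like the sketch below. It is illustrative only: it assumes hardware where the cuDNN block-scaled path is available (per the Notes, Blackwell-class GPUs with a recent cuDNN), and the comparison against `lax.dot_general` only illustrates the output shape and the size of the quantization error::

    import functools
    import jax
    import jax.numpy as jnp
    from jax import random

    keys = random.split(random.key(0), 2)
    lhs = random.normal(keys[0], (3, 128, 64), dtype=jnp.bfloat16)
    rhs = random.normal(keys[1], (3, 128, 64), dtype=jnp.bfloat16)

    # One config each for lhs, rhs, and the incoming gradient.
    configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
    scaled_dot = functools.partial(jax.nn.scaled_dot_general, configs=configs)

    dn = (((2,), (2,)), ((0,), (0,)))  # contract dim 2 of both operands, batch over dim 0
    out = scaled_dot(lhs, rhs, dn)                 # float32 output of shape (3, 128, 128)
    ref = jax.lax.dot_general(lhs, rhs, dn, preferred_element_type=jnp.float32)
    print(out.shape, jnp.max(jnp.abs(out - ref)))  # difference reflects quantization error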

@@ -38,6 +38,7 @@ from jax._src.nn.functions import (
identity as identity,
relu6 as relu6,
dot_product_attention as dot_product_attention,
get_scaled_dot_general_config as get_scaled_dot_general_config,
scaled_dot_general as scaled_dot_general,
scaled_matmul as scaled_matmul,
selu as selu,
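
A quick smoke check that the re-exports above actually surface on `jax.nn` (an editor's sketch, not part of the commit)::

    import jax

    for name in ('scaled_matmul', 'get_scaled_dot_general_config', 'scaled_dot_general'):
        # Each name should be importable from the public jax.nn namespace after this change.
        assert hasattr(jax.nn, name), f"jax.nn.{name} is not exported"
        print(name, '->', getattr(jax.nn, name))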

@@ -31,7 +31,6 @@ from jax._src.lib import cuda_versions
from jax._src.cudnn.scaled_matmul_stablehlo import (
quantize,
shape_normalization,
BlockScaleConfig,
)
from jax.test_util import check_grads
from jax import nn
@@ -110,17 +109,7 @@ def create_mxfp8_configs_if_available():
if _dtypes.float8_e8m0fnu is None:
raise unittest.SkipTest("float8_e8m0fnu is not available.")

def _create_mxfp8_config():
return BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)

return [_create_mxfp8_config() for _ in range(3)]
return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)]


@jtu.with_config(jax_legacy_prng_key="allow",
@@ -130,10 +119,9 @@ class NNFunctionsTest(jtu.JaxTestCase):
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
impl=['cudnn',],
)
def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
def testScaledMatmul(self, contract, lhs_non_contract, dtype):
if not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
# Check if float8_e8m0fnu is available
configs = create_mxfp8_configs_if_available()
@@ -153,11 +141,10 @@ class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
is_training=[True, False],
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
impl=['cudnn',],
)
def testScaledDotGeneral(
self, is_training, output_type, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
self, is_training, output_type):
if not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")

configs = create_mxfp8_configs_if_available()