Merge pull request #27919 from kaixih:enable_doc_scaled_dot_fix

PiperOrigin-RevId: 747578845
Author: jax authors
Date:   2025-04-14 14:55:23 -07:00
Commit: 19be20fc6f

@@ -22,6 +22,7 @@ import operator
 import math
 import numpy as np
 from typing import Any, List, Literal
+import warnings
 import jax
 import jax.numpy as jnp
@@ -1210,10 +1211,10 @@ def dot_product_attention(
   return jnp.reshape(out, output_shape)

 def scaled_matmul(
-    a: Array,
-    b: Array,
-    a_scales: Array,
-    b_scales: Array,
+    lhs: Array,
+    rhs: Array,
+    lhs_scales: Array,
+    rhs_scales: Array,
     preferred_element_type: DTypeLike = jnp.float32,
 ) -> Array:
   r"""Scaled matrix multiplication function.
@@ -1230,10 +1231,10 @@ def scaled_matmul(
     jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)

   Args:
-    a (Array): Shape (B, M, K).
-    b (Array): Shape (B, N, K).
-    a_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
-    b_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
+    lhs (Array): Operand a, shape (B, M, K).
+    rhs (Array): Operand b, shape (B, N, K).
+    lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
+    rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
     preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.

   Returns:
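For reference, a minimal sketch of a call using the renamed keyword arguments. It assumes the function is exposed as `jax.nn.scaled_matmul` and, like the skipped doctests below, a backend with block-scaled FP8 support (e.g. cuDNN) plus ml_dtypes >= 0.5.0 for `float8_e8m0fnu`; shapes, block size, and dtypes are illustrative, chosen to satisfy the constraints documented above.

# Hedged sketch: block-scaled matmul via the renamed lhs/rhs keywords.
import jax.numpy as jnp
from jax import nn

B, M, N, K, block = 3, 128, 128, 64, 32
lhs = jnp.ones((B, M, K), dtype=jnp.float8_e4m3fn)                   # (B, M, K)
rhs = jnp.ones((B, N, K), dtype=jnp.float8_e4m3fn)                   # (B, N, K)
lhs_scales = jnp.ones((B, M, K // block), dtype=jnp.float8_e8m0fnu)  # K % K_a == 0
rhs_scales = jnp.ones((B, N, K // block), dtype=jnp.float8_e8m0fnu)  # K % K_b == 0

out = nn.scaled_matmul(lhs=lhs, rhs=rhs,
                       lhs_scales=lhs_scales, rhs_scales=rhs_scales)
print(out.shape)  # (3, 128, 128), float32 by default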
@@ -1274,6 +1275,7 @@ def scaled_matmul(
     >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
     >>> scaled_matmul(a, b, a_scales, b_scales)  # doctest: +SKIP
   """
+  a, b, a_scales, b_scales = lhs, rhs, lhs_scales, rhs_scales
   if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
     raise ValueError(
         "scaled_matmul requires all inputs to be 3-dimensional arrays"
@@ -1353,6 +1355,7 @@ def scaled_dot_general(
     dimension_numbers,
     preferred_element_type=jnp.float32,
     configs: List[BlockScaleConfig] | None = None,
+    implementation: Literal['cudnn'] | None = None,
     ):
   r"""Scaled dot general operation.

@@ -1379,6 +1382,9 @@ def scaled_dot_general(
       lhs, rhs, and gradients. Users can obtain valid configurations via
       `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
       are supported. If `None`, falls back to `lax.dot_general`.
+    implementation: Literal['cudnn'] | None
+      (Deprecated) Backend selector; the value is ignored and the backend is
+      chosen automatically. Scheduled for removal in a future release.

   Returns:
     Array: The resulting tensor, with batch dimensions first, followed by
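A sketch of how the updated signature is intended to be used, mirroring the docstring above: per-operand configs come from `jax.nn.get_scaled_dot_general_config`, and any value passed for `implementation` is accepted only for backward compatibility and triggers a `DeprecationWarning`. The exact helper signature is an assumption, and, like the skipped doctests, the block-scaled path needs cuDNN-capable hardware.

# Hedged sketch: three configs (lhs, rhs, gradients) plus the deprecated
# `implementation` keyword, which is now ignored.
import functools
import jax
from jax import nn

configs = [nn.get_scaled_dot_general_config('mxfp8')] * 3
scaled_dot = functools.partial(nn.scaled_dot_general, configs=configs)

lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64))
rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
dims = (((2,), (2,)), ((0,), (0,)))       # contract over K, batch over B

out = scaled_dot(lhs, rhs, dims, implementation='cudnn')  # DeprecationWarning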
@@ -1412,19 +1418,13 @@ def scaled_dot_general(
     >>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
     >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))  # doctest: +SKIP
   """
-  # Create configs if not provided
+  if implementation is not None:
+    warnings.warn("The `implementation` argument is deprecated and ignored; "
+                  "the backend is chosen automatically.", DeprecationWarning)
   if configs is None:
     if dtypes.float8_e8m0fnu is None:
       raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
-    mxfp8_config = BlockScaleConfig(
-        mode='mxfp8',
-        block_size=32,
-        data_type=jnp.float8_e4m3fn,
-        scale_type=jnp.float8_e8m0fnu,
-        global_scale=None,
-        infer_only=False
-    )
-    configs = [mxfp8_config for _ in range(3)]
+    return lax.dot_general(lhs, rhs, dimension_numbers,
+                           preferred_element_type=preferred_element_type)

   out = cudnn_scaled_dot_general(
       lhs, rhs, dimension_numbers,
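With the new body, leaving `configs=None` reduces the call to a plain `lax.dot_general`, so the sketch below should run on any backend (it still needs ml_dtypes >= 0.5.0 to pass the `float8_e8m0fnu` check above); only the public `jax.nn.scaled_dot_general` spelling is assumed here.

# Hedged sketch of the configs=None fallback path: no block scaling is
# applied and the result should match lax.dot_general directly.
import jax
import jax.numpy as jnp
from jax import lax, nn

lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64))
rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
dims = (((2,), (2,)), ((0,), (0,)))

fallback = nn.scaled_dot_general(lhs, rhs, dims)   # configs=None, lax path
reference = lax.dot_general(lhs, rhs, dims, preferred_element_type=jnp.float32)
print(jnp.allclose(fallback, reference))           # expected: True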