Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Merge pull request #27919 from kaixih:enable_doc_scaled_dot_fix
PiperOrigin-RevId: 747578845
Commit 19be20fc6f
@@ -22,6 +22,7 @@ import operator
 import math
 import numpy as np
 from typing import Any, List, Literal
+import warnings
 
 import jax
 import jax.numpy as jnp
@@ -1210,10 +1211,10 @@ def dot_product_attention(
   return jnp.reshape(out, output_shape)
 
 def scaled_matmul(
-    a: Array,
-    b: Array,
-    a_scales: Array,
-    b_scales: Array,
+    lhs: Array,
+    rhs: Array,
+    lhs_scales: Array,
+    rhs_scales: Array,
     preferred_element_type: DTypeLike = jnp.float32,
     ) -> Array:
   r"""Scaled matrix multiplication function.
@@ -1230,10 +1231,10 @@ def scaled_matmul(
       jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
 
   Args:
-    a (Array): Shape (B, M, K).
-    b (Array): Shape (B, N, K).
-    a_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
-    b_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
+    lhs (Array): Operand a, shape (B, M, K).
+    rhs (Array): Operand b, shape (B, N, K).
+    lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
+    rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
     preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.
 
   Returns:
@@ -1274,6 +1275,7 @@ def scaled_matmul(
     >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
     >>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP
   """
+  a, b, a_scales, b_scales = lhs, rhs, lhs_scales, rhs_scales
   if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
     raise ValueError(
         "scaled_matmul requires all inputs to be 3-dimensional arrays"
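Note on the renamed operands: the docstring above still defines the result as jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled), with each scale entry covering one contiguous block of K // K_a (resp. K // K_b) elements along the contraction dimension. Below is a minimal pure-JAX sketch of that reference behaviour under those assumptions; the helper name reference_scaled_matmul and the jnp.repeat block broadcast are illustrative, not the library implementation.

import jax.numpy as jnp

def reference_scaled_matmul(lhs, rhs, lhs_scales, rhs_scales,
                            preferred_element_type=jnp.float32):
  # Broadcast each scale over its contiguous block along K, upcast, then contract.
  k = lhs.shape[-1]
  a_scaled = lhs.astype(preferred_element_type) * jnp.repeat(
      lhs_scales.astype(preferred_element_type), k // lhs_scales.shape[-1], axis=-1)
  b_scaled = rhs.astype(preferred_element_type) * jnp.repeat(
      rhs_scales.astype(preferred_element_type), k // rhs_scales.shape[-1], axis=-1)
  return jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)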
@@ -1353,6 +1355,7 @@ def scaled_dot_general(
     dimension_numbers,
     preferred_element_type=jnp.float32,
     configs: List[BlockScaleConfig] | None = None,
+    implementation: Literal['cudnn'] | None = None,
     ):
   r"""Scaled dot general operation.
 
@@ -1379,6 +1382,9 @@
       lhs, rhs, and gradients. Users can obtain valid configurations via
       `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
       are supported. If `None`, falls back to `lax.dot_general`.
+    implementation: str
+      (Deprecated) Backend selector, now ignored. The system chooses the backend
+      automatically. Scheduled for removal in future releases.
 
   Returns:
     Array: The resulting tensor, with batch dimensions first, followed by
@@ -1412,19 +1418,13 @@ def scaled_dot_general(
     >>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
     >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,)))) # doctest: +SKIP
   """
-  # Create configs if not provided
+  if implementation is not None:
+    warnings.warn("Backend selector, now ignored. The system chooses the "
+                  "backend automatically.", DeprecationWarning)
+
   if configs is None:
-    if dtypes.float8_e8m0fnu is None:
-      raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
-    mxfp8_config = BlockScaleConfig(
-        mode='mxfp8',
-        block_size=32,
-        data_type=jnp.float8_e4m3fn,
-        scale_type=jnp.float8_e8m0fnu,
-        global_scale=None,
-        infer_only=False
-    )
-    configs = [mxfp8_config for _ in range(3)]
+    return lax.dot_general(lhs, rhs, dimension_numbers,
+                           preferred_element_type=preferred_element_type)
 
   out = cudnn_scaled_dot_general(
       lhs, rhs, dimension_numbers,
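Taken together, the hunk above changes the call-site behaviour in two ways: passing the deprecated implementation argument now only emits a DeprecationWarning, and configs=None now falls straight back to lax.dot_general instead of building a default mxfp8 config. A hedged sketch of both paths follows, assuming the function is exported as jax.nn.scaled_dot_general (the public path is an assumption; the docstring only names jax.nn.get_scaled_dot_general_config).

import warnings
import jax

lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64))
rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
dimension_numbers = (((2,), (2,)), ((0,), (0,)))

# With configs=None this now reduces to a plain lax.dot_general call.
out = jax.nn.scaled_dot_general(lhs, rhs, dimension_numbers)

# The deprecated backend selector is still accepted, but only triggers a warning.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  _ = jax.nn.scaled_dot_general(lhs, rhs, dimension_numbers, implementation='cudnn')
assert any(issubclass(w.category, DeprecationWarning) for w in caught)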