mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)

block_scale_config

This commit is contained in:
parent 061d4acbfb
commit 332af58765
@@ -1,6 +1,8 @@
from dataclasses import dataclass
import json
import operator
from functools import partial, reduce
from typing import List

# Third-party imports
import jax
@@ -21,6 +23,26 @@ from jax.sharding import PartitionSpec as P
Array = jnp.ndarray
nv_scaled_dot_name = "__nv$scaled_dot"

@dataclass
class BlockScaleConfig:
  mode: str
  block_size: int
  data_type: DTypeLike
  scale_type: DTypeLike
  global_scale: Array | None
  infer_only: bool

mxfp8_config = BlockScaleConfig(
    mode='mxfp8',
    block_size=32,
    data_type=jnp.float8_e4m3fn,
    scale_type=jnp.float8_e8m0fnu,
    global_scale=None,
    infer_only=False
)

BlockScaleConfigs = List[BlockScaleConfig]
mxfp8_configs: BlockScaleConfigs = [mxfp8_config, mxfp8_config, mxfp8_config]

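A `BlockScaleConfigs` value is just a list of three `BlockScaleConfig` entries, which the backward pass later in this file indexes as the lhs, rhs, and incoming-gradient configs. As a hedged illustration only (nothing below is defined by this commit), an inference-only mxfp8 variant could be assembled the same way:

# Illustrative sketch, not part of this commit: an assumed inference-only
# mxfp8 config list, reusing the field values of mxfp8_config above.
mxfp8_infer_config = BlockScaleConfig(
    mode='mxfp8',
    block_size=32,
    data_type=jnp.float8_e4m3fn,
    scale_type=jnp.float8_e8m0fnu,
    global_scale=None,
    infer_only=True,
)
mxfp8_infer_configs: BlockScaleConfigs = [mxfp8_infer_config] * 3
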
def default_layouts(*shapes):
  return [range(len(shape) - 1, -1, -1) for shape in shapes]
@@ -301,7 +323,7 @@ def _scaled_matmul(
  )
  return output[0]

def scaled_matmul(
def scaled_matmul_wrapper(
    lhs: Array,
    rhs: Array,
    lhs_scales: Array,
@@ -447,39 +469,92 @@ def e8m0_to_dtype(x, dtype):
  )
  return new_x.astype(dtype)


def quantize_core(x, q_dtype, scale):
  scaled_x = x / scale
  MAX = jnp.finfo(q_dtype).max.astype(x.dtype)
  clipped_x = jnp.clip(scaled_x, -MAX, MAX)
  return clipped_x.astype(q_dtype)


def quantize(x, quantize_type, block_size=32):
def quantize(x, config):
  x_shape = x.shape
  contract_dim = x_shape[-1]
  block_size = config.block_size
  assert contract_dim >= block_size and contract_dim % block_size == 0
  x_new_shape = x_shape[:-1] + (x_shape[-1] // block_size, block_size)
  x = x.reshape(x_new_shape)  # shape = (B, M, K / block_size, block_size)

  amax = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
  MAX = jnp.finfo(quantize_type).max.astype(x.dtype)
  MAX = jnp.finfo(config.data_type).max.astype(x.dtype)
  scales = amax / MAX  # shape = (B, M, K / block_size, 1)

  scales_q = cast_to_e8m0_with_rounding_up(scales)
  scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
  if config.mode == "mxfp8":
    assert config.scale_type == jnp.float8_e8m0fnu
    scales_q = cast_to_e8m0_with_rounding_up(scales)
    scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
  elif config.mode == "nvfp4":
    assert config.scale_type == jnp.float8_e4m3fn
    # shuw(TODO): placeholder
    scales_q = scales
    scaled_x = x

  clipped_x = jnp.clip(scaled_x, -MAX, MAX)
  x_q = clipped_x.astype(quantize_type)
  x_q = clipped_x.astype(config.data_type)

  x_q = x_q.reshape(x_shape)  # shape = (B, M, K)
  scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view(
      jnp.float8_e8m0fnu
      config.scale_type
  )
  return x_q, scales_q


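For orientation, the reworked `quantize` scales each `block_size`-wide slice of the contraction dimension by its own factor and returns the quantized tensor together with the per-block scales. A minimal shape sketch under the `mxfp8_config` defined above, assuming this module's `quantize` and `mxfp8_config` are in scope as in the tests below:

# Hedged sketch of the shapes quantize() produces for a (B, M, K) input with
# block_size=32; K must be a multiple of the block size.
B, M, K = 2, 8, 64
x = jax.random.normal(jax.random.key(0), (B, M, K), dtype=jnp.float32)
x_q, x_scales = quantize(x, mxfp8_config)
assert x_q.shape == (B, M, K) and x_q.dtype == jnp.float8_e4m3fn
assert x_scales.shape == (B, M, K // 32) and x_scales.dtype == jnp.float8_e8m0fnu
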
def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type):
def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
  assert q_dtype in (jnp.float8_e4m3fn, )
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(
      jnp.asarray(scale, dtype=compute_dtype), x.shape
  )
  clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
  return clipped_x.astype(q_dtype)

def quantize_dequantize(x, q_dtype, scale, compute_dtype):
  qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
  out = qx.astype(x.dtype) * jnp.broadcast_to(
      jnp.asarray(scale, dtype=x.dtype), qx.shape
  )
  return out

def generate_quantized_tensors(
    batch, lhs_non_contract, contract, rhs_non_contract,
    configs, dtype=jnp.float32,
):
  cast_to_representable = partial(
      quantize_dequantize,
      scale=jnp.ones((1,)),
      compute_dtype=dtype,
  )

  k1, k2 = jax.random.split(jax.random.key(123), 2)

  a = cast_to_representable(
      jax.random.uniform(
          k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      configs[0].data_type,
  )
  b = cast_to_representable(
      jax.random.uniform(
          k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      configs[1].data_type,
  )

  dn = ((2,), (0,))
  a_3d = shape_normalization(a, dn)
  b_3d = shape_normalization(b, dn)
  a_q, a_scales = quantize(a, configs[0])
  b_q, b_scales = quantize(b, configs[1])

  return a, b, a_q, b_q, a_scales, b_scales


def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
                    configs):
  if preferred_element_type is None:
    preferred_element_type = dtypes.result_type(
        lhs, rhs, return_weak_type_flag=False
@@ -495,10 +570,11 @@ def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type):

  lhs_3d = shape_normalization(lhs, lhs_dn)
  rhs_3d = shape_normalization(rhs, rhs_dn)
  lhs_q, lhs_scales = quantize(lhs_3d, jnp.float8_e4m3fn)
  rhs_q, rhs_scales = quantize(rhs_3d, jnp.float8_e4m3fn)
  lhs_config, rhs_config = configs[0], configs[1]
  lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
  rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)

  out = scaled_matmul(
  out = scaled_matmul_wrapper(
      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type
  )

@@ -509,8 +585,9 @@ def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type):
  return expanded_out


def mxfp8_dot_general_transpose_lhs(
    g, x, y, *, dimension_numbers, preferred_element_type, swap_ans=False
def scaled_dot_general_transpose_lhs(
    g, x, y, *, dimension_numbers, preferred_element_type, configs,
    swap_ans=False
):
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  x_ndim = x.aval.ndim
@@ -530,10 +607,15 @@ def mxfp8_dot_general_transpose_lhs(

  y_3d = shape_normalization(y, y_dn)
  g_3d = shape_normalization(g, g_dn)
  g_q, g_scales = quantize(g_3d, jnp.float8_e4m3fn)
  y_q, y_scales = quantize(y_3d, jnp.float8_e4m3fn)

  out = scaled_matmul(g_q, y_q, g_scales, y_scales, preferred_element_type)
  g_config, y_config = configs[0], configs[1]

  g_q, g_scales = quantize(g_3d, g_config)
  y_q, y_scales = quantize(y_3d, y_config)

  out = scaled_matmul_wrapper(
      g_q, y_q, g_scales, y_scales, preferred_element_type
  )

  expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn)
  expanded_out = jnp.reshape(out, expanded_out_shape)
@@ -541,34 +623,40 @@ def mxfp8_dot_general_transpose_lhs(
  return x_bar


def mxfp8_dot_general_transpose_rhs(
    g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike | None
def scaled_dot_general_transpose_rhs(
    g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike,
    configs: BlockScaleConfigs
):
  (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
  swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
  y_bar = mxfp8_dot_general_transpose_lhs(
  y_bar = scaled_dot_general_transpose_lhs(
      g,
      y,
      x,
      dimension_numbers=swapped_dimension_numbers,
      preferred_element_type=preferred_element_type,
      configs=configs,
      swap_ans=True,
  )
  return y_bar


@partial(custom_vjp, nondiff_argnums=(2, 3))
def mxfp8_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type):
  return mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type)
@partial(custom_vjp, nondiff_argnums=(2, 3, 4))
def scaled_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type,
                          configs):
  return scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
                         configs)


def mxfp8_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type):
  out = mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type)
def scaled_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type,
                   configs):
  out = scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
                        configs)
  res = (lhs, rhs)
  return out, res


def mxfp8_dot_bwd(dimension_numbers, preferred_element_type, res, g):
def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g):
  (lhs, rhs) = res

  args = [g, lhs, rhs]
@@ -576,12 +664,20 @@ def mxfp8_dot_bwd(dimension_numbers, preferred_element_type, res, g):
      "dimension_numbers": dimension_numbers,
      "preferred_element_type": preferred_element_type,
  }
  grad_lhs = mxfp8_dot_general_transpose_lhs(*args, **kw_args)
  grad_rhs = mxfp8_dot_general_transpose_rhs(*args, **kw_args)
  lhs_kw_args = {
      **kw_args,
      "configs": [configs[2], configs[1]]
  }
  rhs_kw_args = {
      **kw_args,
      "configs": [configs[2], configs[0]]
  }
  grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args)
  grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args)
  return (grad_lhs, grad_rhs)


mxfp8_dot_general_fn.defvjp(mxfp8_dot_fwd, mxfp8_dot_bwd)
scaled_dot_general_fn.defvjp(scaled_dot_fwd, scaled_dot_bwd)


def ensure_tuple(dimension_numbers):
@@ -610,9 +706,10 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers):
  return lhs_batched, rhs_batched, dn_batched


# TODO(shuw): mxfp8_dot_general should be in nn.function when upstreaming.
def mxfp8_dot_general(
    lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32
def scaled_dot_general_wrapper(
    lhs, rhs, dimension_numbers,
    preferred_element_type=jnp.float32,
    configs: BlockScaleConfigs=mxfp8_configs,
):
  if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16):
    msg = ('Only support preferred_element_type in (f32, bf16, f16), but got '
@@ -622,8 +719,8 @@ def mxfp8_dot_general(
  lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim(
      lhs, rhs, dimension_numbers
  )
  out = mxfp8_dot_general_fn(
      lhs_batched, rhs_batched, dn_batched, preferred_element_type
  out = scaled_dot_general_fn(
      lhs_batched, rhs_batched, dn_batched, preferred_element_type, configs,
  )

  # Expanding batch dims for operands adds a singleton batch dim at axis 0 in
@@ -631,4 +728,3 @@ def mxfp8_dot_general(
  if dn_batched != dimension_numbers:
    return jnp.squeeze(out, axis=0)
  return out

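Judging from `scaled_dot_bwd` above, the three entries of a `BlockScaleConfigs` list map to the lhs, rhs, and incoming gradient: the lhs gradient quantizes `(g, rhs)` with `[configs[2], configs[1]]`, and the rhs gradient quantizes `(g, lhs)` with `[configs[2], configs[0]]`. A hedged usage sketch of the wrapper under that assumed ordering (requires a cuDNN build with block-scaled matmul support; not part of this commit):

# Assumed ordering of the configs list: [lhs_config, rhs_config, grad_config].
# With the default mxfp8_configs all three entries are the same mxfp8_config.
lhs = jnp.ones((2, 128, 64), dtype=jnp.bfloat16)
rhs = jnp.ones((2, 96, 64), dtype=jnp.bfloat16)
dn = (([2], [2]), ([0], [0]))  # contract the shared 64-wide dimension
out = scaled_dot_general_wrapper(
    lhs, rhs, dn,
    preferred_element_type=jnp.bfloat16,
    configs=[mxfp8_config, mxfp8_config, mxfp8_config],
)  # out has shape (2, 128, 96)
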
@@ -36,10 +36,14 @@ from jax._src.core import AxisName
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
    scaled_matmul_wrapper as cudnn_scaled_matmul,
    scaled_dot_general_wrapper as cudnn_scaled_dot_general,
    BlockScaleConfigs, mxfp8_configs)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike, DType
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.ops.special import logsumexp as _logsumexp

@@ -1159,3 +1163,92 @@ def dot_product_attention(
    raise ValueError(f"Unsupported implementation option: {implementation}")

  return jnp.reshape(out, output_shape)

def scaled_matmul(
    lhs: Array,
    rhs: Array,
    lhs_scales: Array,
    rhs_scales: Array,
    preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
  r"""
  Performs scaled matrix multiplication between two 3D arrays, with scaling
  factors applied to the matrices.

  .. math::
    \mathrm{ScaledMatmul}(lhs, rhs, lhs\_scales, rhs\_scales) = lhs\_scales \cdot rhs\_scales \cdot \mathrm{dot}(lhs, rhs)

  Args:
    lhs (Array): A 3D array of shape (B, M, K).
    rhs (Array): A 3D array of shape (B, N, K).
    lhs_scales (Array): A 3D array of shape (B, M, K_block).
    rhs_scales (Array): A 3D array of shape (B, N, K_block).
    preferred_element_type (DTypeLike, optional): The preferred data type
      for the computation. Defaults to `jnp.float32`.

  Returns:
    Array: A 3D array of shape (B, M, N) representing the scaled matrix
    multiplication result.

  Raises:
    AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
      match the number of columns in `rhs` (`rhs_K`).

  Notes:
    - The function ensures that the `preferred_element_type` is
      canonicalized before passing it to the underlying computation.
    - Scaling is applied to the matrices based on the `lhs_scales` and
      `rhs_scales` arrays, enabling efficient computations in blocks.
  """
  B, M, lhs_K = lhs.shape
  _, N, rhs_K = rhs.shape
  assert lhs_K == rhs_K
  _, _, K_block = lhs_scales.shape

  preferred_element_type = dtypes.canonicalize_dtype(
      np.dtype(preferred_element_type)
  )
  out = cudnn_scaled_matmul(
      lhs,
      rhs,
      lhs_scales,
      rhs_scales,
      preferred_element_type=preferred_element_type,
  )
  return out

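A short usage sketch for the new `jax.nn.scaled_matmul`, mirroring `testScaledMatmul` below; `generate_quantized_tensors` and `mxfp8_configs` come from `jax._src.cudnn.scaled_matmul_stablehlo`, and running it requires a cuDNN build with block-scaled matmul support:

import jax.numpy as jnp
from jax import nn
from jax._src.cudnn.scaled_matmul_stablehlo import (
    generate_quantized_tensors, mxfp8_configs)

# Quantize a (4, 240, 160) lhs and a (4, 256, 160) rhs into mxfp8 blocks,
# then contract their shared 160-wide dimension.
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
    4, 240, 160, 256, mxfp8_configs, dtype=jnp.bfloat16,
)
out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
                       preferred_element_type=jnp.bfloat16)  # (4, 240, 256)
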
def scaled_dot_general(
    lhs, rhs,
    dimension_numbers,
    preferred_element_type=jnp.float32,
    configs: BlockScaleConfigs=mxfp8_configs,
):
  r"""Scaled dot general operation.

  Computes the scaled dot general on lhs, rhs with quantization specified by configs:

  .. math::
    \widehat{lhs}, s_a = \mathrm{quantize}(lhs)
    \widehat{rhs}, s_b = \mathrm{quantize}(rhs)
    \mathrm{ScaledDot}(lhs, rhs) = s_a s_b \mathrm{dot}(\widehat{lhs}, \widehat{rhs})

  Args:
    lhs: Left-hand side input tensor.
    rhs: Right-hand side input tensor.
    dimension_numbers: A tuple specifying the contraction and batch dimensions
      for the dot general operation. Must follow the format:
      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
    preferred_element_type: The preferred output data type. Supported types are
      `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
    configs: A `BlockScaleConfigs` object specifying the scaling
      configurations for the operation. Defaults to `mxfp8_configs`.

  Returns:
    The result of the scaled dot general operation.
  """
  return cudnn_scaled_dot_general(
      lhs, rhs, dimension_numbers,
      preferred_element_type=preferred_element_type,
      configs=configs
  )

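And an end-to-end sketch for `scaled_dot_general`, following the shapes in `NNFunctionsTest.testScaledDotGeneral` below. Quantization happens inside the call, so the operands are ordinary bf16 arrays, and gradients flow through the custom VJP defined in `scaled_matmul_stablehlo`; again, a CUDA/cuDNN build with mxfp8 support is assumed:

import jax
import jax.numpy as jnp
from jax import nn

k1, k2 = jax.random.split(jax.random.key(0))
a = jax.random.uniform(k1, (2, 256, 96), minval=-1.0, dtype=jnp.bfloat16)
b = jax.random.uniform(k2, (2, 96, 160), minval=-1.0, dtype=jnp.bfloat16)
dn = (([2], [1]), ([0], [0]))  # contract a's last dim with b's middle dim

def loss(a, b):
  y = nn.scaled_dot_general(a, b, dn, preferred_element_type=jnp.bfloat16)
  return jnp.sum(y)

out, (da, db) = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))(a, b)
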
@@ -37,6 +37,8 @@ from jax._src.nn.functions import (
  relu as relu,
  relu6 as relu6,
  dot_product_attention as dot_product_attention,
  scaled_dot_general as scaled_dot_general,
  scaled_matmul as scaled_matmul,
  selu as selu,
  sigmoid as sigmoid,
  soft_sign as soft_sign,
@@ -29,6 +29,11 @@ from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.cudnn.scaled_matmul_stablehlo import (
    generate_quantized_tensors,
    mxfp8_configs,
    quantize_dequantize,
)
from jax.test_util import check_grads
from jax import nn
from jax import random
@@ -54,6 +59,80 @@ _cudnn_dbias_error = 'cuDNN only supports bias gradient'
@jtu.with_config(jax_legacy_prng_key="allow",
                 jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
  @parameterized.product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
      configs=[mxfp8_configs,],
  )
  def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs):
    batch, rhs_non_contract = 4, 256
    a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
        batch, lhs_non_contract, contract, rhs_non_contract,
        configs, dtype=dtype,
    )
    out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
                           preferred_element_type=dtype)
    out_ref = jnp.matmul(a.astype(jnp.float32),
                         jnp.transpose(b, (0, 2, 1)).astype(jnp.float32))
    self.assertArraysAllClose(
        out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
    )

  @parameterized.product(
      is_training=[True, False],
      output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
      configs=[mxfp8_configs,],
  )
  def testScaledDotGeneral(
      self, is_training, output_type, configs,
  ):
    cast_to_representable = partial(
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )
    k1, k2 = jax.random.split(jax.random.key(0), 2)
    a_shape = [2, 256, 96]
    b_shape = [2, 96, 160]
    dimension_numbers = (([2], [1]), ([0], [0]))
    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
        configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
        configs[1].data_type,
    )

    scaled_dot_general_fn = partial(
        nn.scaled_dot_general, configs=configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.lax.dot_general if is_ref else scaled_dot_general_fn
      y = fn(a, b, dimension_numbers,
             preferred_element_type=output_type)
      return jnp.sum(y)

    if is_training:
      j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))

      j_train_ref = jax.jit(
          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
      )
      out, (x_grad, w_grad) = j_train(a, b)
      out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)

      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
    else:
      j_inference = jax.jit(fwd)
      j_inference_ref = jax.jit(partial(fwd, is_ref=True))
      out = j_inference(a, b)
      out_ref = j_inference_ref(a, b)
      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)

  @parameterized.product(
      dtype=[jnp.bfloat16, jnp.float16],
      group_num=[1, 2, 4],
@@ -11,10 +11,11 @@ from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.scaled_matmul_stablehlo import (
    scaled_matmul,
    mxfp8_dot_general,
    quantize,
    shape_normalization,
    scaled_matmul_wrapper,
    scaled_dot_general_wrapper,
    mxfp8_configs,
    generate_quantized_tensors,
    quantize_dequantize,
)

@@ -42,57 +43,6 @@ expected_hlos = [
]
sharding_configs = [[i, j] for i, j in zip(input_sharding_configs, expected_hlos)]

def quantize_to_fp8(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
  assert q_dtype in (jnp.float8_e4m3fn, )
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(
      jnp.asarray(scale, dtype=compute_dtype), x.shape
  )
  clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
  return clipped_x.astype(q_dtype)

def quantize_dequantize_fp8(x, q_dtype, scale, compute_dtype):
  qx = quantize_to_fp8(x, q_dtype, compute_dtype, scale)
  out = qx.astype(x.dtype) * jnp.broadcast_to(
      jnp.asarray(scale, dtype=x.dtype), qx.shape
  )
  return out

def generate_quantized_tensors(
    batch, lhs_non_contract, contract, rhs_non_contract, dtype=jnp.float32
):
  cast_to_representable = partial(
      quantize_dequantize_fp8,
      scale=jnp.ones((1,)),
      compute_dtype=dtype,
  )

  k1, k2 = jax.random.split(jax.random.key(123), 2)

  f8_dtype = jnp.float8_e4m3fn

  a = cast_to_representable(
      jax.random.uniform(
          k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      f8_dtype,
  )
  b = cast_to_representable(
      jax.random.uniform(
          k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      f8_dtype,
  )

  dn = ((2,), (0,))
  a_3d = shape_normalization(a, dn)
  b_3d = shape_normalization(b, dn)
  a_q, a_scales = quantize(a, f8_dtype)
  b_q, b_scales = quantize(b, f8_dtype)

  return a, b, a_q, b_q, a_scales, b_scales

def shard_and_device_put(
    mesh, a_sharding, b_sharding, a, b, a_scales=None, b_scales=None
@@ -126,19 +76,19 @@ def shard_and_device_put(
  return a, b, in_shardings


def get_hlo_text(in_shardings):
def get_hlo_text(in_shardings, block_scale_configs=mxfp8_configs):
  mesh_names = ("dp", "tp")
  devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
  mesh = Mesh(devices, mesh_names)
  _, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
      2, 512, 1024, 512
      2, 512, 1024, 512, block_scale_configs,
  )

  with mesh:
    a_q, b_q, a_scales, b_scales, in_shardings = shard_and_device_put(
        mesh, in_shardings[0], in_shardings[1], a_q, b_q, a_scales, b_scales
    )
    pjit_fn = jax.jit(scaled_matmul, in_shardings=in_shardings)
    pjit_fn = jax.jit(scaled_matmul_wrapper, in_shardings=in_shardings)
    hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
  return hlo.as_text()

@@ -180,16 +130,20 @@ class ScaledMatmulTest(jtu.JaxTestCase):
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul(self, contract, lhs_non_contract, dtype):
  def test_scaled_matmul(
      self, contract, lhs_non_contract, dtype, block_scale_configs,
  ):
    batch, rhs_non_contract = 2, 128
    a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
        batch, lhs_non_contract, contract, rhs_non_contract, dtype=dtype
        batch, lhs_non_contract, contract, rhs_non_contract,
        block_scale_configs, dtype=dtype,
    )

    def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
      return scaled_matmul(
      return scaled_matmul_wrapper(
          lhs,
          rhs,
          lhs_scales,
@@ -216,15 +170,18 @@ class ScaledMatmulTest(jtu.JaxTestCase):
        out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
    )

  @jtu.sample_product(sharding_config=sharding_configs)
  @jtu.sample_product(
      sharding_config=sharding_configs,
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul_sharded(self, sharding_config):
  def test_scaled_matmul_sharded(self, sharding_config, block_scale_configs):
    if len(jax.local_devices()) < 4:
      self.skipTest("Require at least 4 devices to run sharding tests.")
    batch, contract, non_contract = 2, 1024, 256

    a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
        batch, non_contract, contract, non_contract
        batch, non_contract, contract, non_contract, block_scale_configs,
    )

    devices = np.array(jax.local_devices()[:4])
@@ -246,7 +203,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):

    args = [a_q, b_q, a_scales, b_scales]
    j_scaled_matmul = jax.jit(
        scaled_matmul, in_shardings=input_shardings
        scaled_matmul_wrapper, in_shardings=input_shardings
    )
    hlo_text = j_scaled_matmul.lower(*args).compile().as_text()
    hlo_pattern = re.compile(
@@ -268,7 +225,9 @@ class ScaledMatmulTest(jtu.JaxTestCase):


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class Mxfp8DotGeneralTest(jtu.JaxTestCase):
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):

  block_scale_configs = mxfp8_configs

  def setUp(self):
    super().setUp()
@@ -302,7 +261,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
  @jtu.run_on_devices("cuda")
  def test_dot_general(self, configs, output_type):
    cast_to_representable = partial(
        quantize_dequantize_fp8,
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )
@@ -311,16 +270,21 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
    a_shape, b_shape, dimension_numbers, is_training = configs
    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
        jnp.float8_e4m3fn,
        self.block_scale_configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
        jnp.float8_e4m3fn,
        self.block_scale_configs[1].data_type,
    )

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=self.block_scale_configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.lax.dot_general if is_ref else mxfp8_dot_general
      y = fn(a, b, dimension_numbers, preferred_element_type=output_type)
      fn = jax.lax.dot_general if is_ref else scaled_dot_general
      y = fn(a, b, dimension_numbers,
             preferred_element_type=output_type)
      return jnp.sum(y)

    if is_training:
@@ -349,7 +313,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
      self.skipTest("Require at least 4 devices to run sharding tests.")

    cast_to_representable = partial(
        quantize_dequantize_fp8,
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )
@@ -360,14 +324,20 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):

    k1, k2 = jax.random.split(jax.random.key(0), 2)
    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0), jnp.float8_e4m3fn
        jax.random.uniform(k1, a_shape, minval=-1.0),
        self.block_scale_configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0), jnp.float8_e4m3fn
        jax.random.uniform(k2, b_shape, minval=-1.0),
        self.block_scale_configs[1].data_type,
    )

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=self.block_scale_configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.lax.dot_general if is_ref else mxfp8_dot_general
      fn = jax.lax.dot_general if is_ref else scaled_dot_general
      y = fn(a, b, dimension_numbers)
      # Use a slightly more complex loss function to avoid constant grads, whose
      # sharding info might be optimized off and then cause issue with the
@@ -415,7 +385,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
  @jtu.run_on_devices("cuda")
  def test_dot_general_vmap(self, configs):
    cast_to_representable = partial(
        quantize_dequantize_fp8,
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )
@@ -426,15 +396,21 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
    dimension_numbers = (([1], [1]), ([], []))

    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0), jnp.float8_e4m3fn
        jax.random.uniform(k1, a_shape, minval=-1.0),
        self.block_scale_configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0), jnp.float8_e4m3fn
        jax.random.uniform(k2, b_shape, minval=-1.0),
        self.block_scale_configs[1].data_type,
    )

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=self.block_scale_configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.vmap(
          jax.lax.dot_general if is_ref else mxfp8_dot_general,
          jax.lax.dot_general if is_ref else scaled_dot_general,
          in_axes=(a_axis, b_axis, None),
          out_axes=o_axis,
      )
@@ -455,5 +431,3 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())