block_scale_config

This commit is contained in:
shuw 2025-02-13 04:20:30 +00:00 committed by JAX Toolbox
parent 061d4acbfb
commit 332af58765
5 changed files with 368 additions and 124 deletions

View File

@@ -1,6 +1,8 @@
from dataclasses import dataclass
import json
import operator
from functools import partial, reduce
from typing import List
# Third-party imports
import jax
@@ -21,6 +23,26 @@ from jax.sharding import PartitionSpec as P
Array = jnp.ndarray
nv_scaled_dot_name = "__nv$scaled_dot"
@dataclass
class BlockScaleConfig:
mode: str
block_size: int
data_type: DTypeLike
scale_type: DTypeLike
global_scale: Array | None
infer_only: bool
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
BlockScaleConfigs = List[BlockScaleConfig]
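# Convention used in this file: a BlockScaleConfigs list holds one config each
# for the lhs, the rhs, and the incoming gradient.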
mxfp8_configs: BlockScaleConfigs = [mxfp8_config, mxfp8_config, mxfp8_config]
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
@@ -301,7 +323,7 @@ def _scaled_matmul(
)
return output[0]
def scaled_matmul(
def scaled_matmul_wrapper(
lhs: Array,
rhs: Array,
lhs_scales: Array,
@@ -447,39 +469,92 @@ def e8m0_to_dtype(x, dtype):
)
return new_x.astype(dtype)
def quantize_core(x, q_dtype, scale):
scaled_x = x / scale
MAX = jnp.finfo(q_dtype).max.astype(x.dtype)
clipped_x = jnp.clip(scaled_x, -MAX, MAX)
return clipped_x.astype(q_dtype)
def quantize(x, quantize_type, block_size=32):
def quantize(x, config):
x_shape = x.shape
contract_dim = x_shape[-1]
block_size = config.block_size
assert contract_dim >= block_size and contract_dim % block_size == 0
x_new_shape = x_shape[:-1] + (x_shape[-1] // block_size, block_size)
x = x.reshape(x_new_shape) # shape = (B, M, K / block_size, block_size)
amax = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
MAX = jnp.finfo(quantize_type).max.astype(x.dtype)
MAX = jnp.finfo(config.data_type).max.astype(x.dtype)
scales = amax / MAX # shape = (B, M, K / block_size, 1)
scales_q = cast_to_e8m0_with_rounding_up(scales)
scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
if config.mode == "mxfp8":
assert config.scale_type == jnp.float8_e8m0fnu
scales_q = cast_to_e8m0_with_rounding_up(scales)
scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
elif config.mode == "nvfp4":
assert config.scale_type == jnp.float8_e4m3fn
# shuw(TODO): placeholder
scales_q = scales
scaled_x = x
clipped_x = jnp.clip(scaled_x, -MAX, MAX)
x_q = clipped_x.astype(quantize_type)
x_q = clipped_x.astype(config.data_type)
x_q = x_q.reshape(x_shape) # shape = (B, M, K)
scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view(
jnp.float8_e8m0fnu
config.scale_type
)
return x_q, scales_q
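For orientation, a minimal sketch (not part of the commit) of the shapes the config-driven `quantize` above produces for the `mxfp8_config` defined earlier; it assumes the module import path used elsewhere in this commit and a JAX build that ships the float8_e8m0fnu dtype:

import jax.numpy as jnp
from jax._src.cudnn.scaled_matmul_stablehlo import quantize, mxfp8_config

x = jnp.ones((4, 128, 256), dtype=jnp.bfloat16)   # (B, M, K), with K % 32 == 0
x_q, scales = quantize(x, mxfp8_config)
assert x_q.shape == (4, 128, 256) and x_q.dtype == jnp.float8_e4m3fn
assert scales.shape == (4, 128, 256 // 32) and scales.dtype == jnp.float8_e8m0fnu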
def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type):
def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
# Explicitly cast the max values to the compute dtype to avoid unnecessary
# casting to FP32 during the subsequent math operations.
assert q_dtype in (jnp.float8_e4m3fn, )
dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
scaled_x = x / jnp.broadcast_to(
jnp.asarray(scale, dtype=compute_dtype), x.shape
)
clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
return clipped_x.astype(q_dtype)
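# Fake-quantization helper: round-trip x through q_dtype (scale, clip, cast,
# then scale back) so that reference inputs are exactly representable there.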
def quantize_dequantize(x, q_dtype, scale, compute_dtype):
qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
out = qx.astype(x.dtype) * jnp.broadcast_to(
jnp.asarray(scale, dtype=x.dtype), qx.shape
)
return out
def generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=jnp.float32,
):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=dtype,
)
k1, k2 = jax.random.split(jax.random.key(123), 2)
a = cast_to_representable(
jax.random.uniform(
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[1].data_type,
)
dn = ((2,), (0,))
a_3d = shape_normalization(a, dn)
b_3d = shape_normalization(b, dn)
a_q, a_scales = quantize(a, configs[0])
b_q, b_scales = quantize(b, configs[1])
return a, b, a_q, b_q, a_scales, b_scales
def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
configs):
if preferred_element_type is None:
preferred_element_type = dtypes.result_type(
lhs, rhs, return_weak_type_flag=False
@@ -495,10 +570,11 @@ def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type):
lhs_3d = shape_normalization(lhs, lhs_dn)
rhs_3d = shape_normalization(rhs, rhs_dn)
lhs_q, lhs_scales = quantize(lhs_3d, jnp.float8_e4m3fn)
rhs_q, rhs_scales = quantize(rhs_3d, jnp.float8_e4m3fn)
lhs_config, rhs_config = configs[0], configs[1]
lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)
out = scaled_matmul(
out = scaled_matmul_wrapper(
lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type
)
@@ -509,8 +585,9 @@ def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type):
return expanded_out
def mxfp8_dot_general_transpose_lhs(
g, x, y, *, dimension_numbers, preferred_element_type, swap_ans=False
def scaled_dot_general_transpose_lhs(
g, x, y, *, dimension_numbers, preferred_element_type, configs,
swap_ans=False
):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = x.aval.ndim
@@ -530,10 +607,15 @@ def mxfp8_dot_general_transpose_lhs(
y_3d = shape_normalization(y, y_dn)
g_3d = shape_normalization(g, g_dn)
g_q, g_scales = quantize(g_3d, jnp.float8_e4m3fn)
y_q, y_scales = quantize(y_3d, jnp.float8_e4m3fn)
out = scaled_matmul(g_q, y_q, g_scales, y_scales, preferred_element_type)
g_config, y_config = configs[0], configs[1]
g_q, g_scales = quantize(g_3d, g_config)
y_q, y_scales = quantize(y_3d, y_config)
out = scaled_matmul_wrapper(
g_q, y_q, g_scales, y_scales, preferred_element_type
)
expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn)
expanded_out = jnp.reshape(out, expanded_out_shape)
@@ -541,34 +623,40 @@ def mxfp8_dot_general_transpose_lhs(
return x_bar
def mxfp8_dot_general_transpose_rhs(
g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike | None
def scaled_dot_general_transpose_rhs(
g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike,
configs: BlockScaleConfigs
):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
y_bar = mxfp8_dot_general_transpose_lhs(
y_bar = scaled_dot_general_transpose_lhs(
g,
y,
x,
dimension_numbers=swapped_dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs,
swap_ans=True,
)
return y_bar
@partial(custom_vjp, nondiff_argnums=(2, 3))
def mxfp8_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type):
return mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type)
@partial(custom_vjp, nondiff_argnums=(2, 3, 4))
def scaled_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type,
configs):
return scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
configs)
def mxfp8_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type):
out = mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type)
def scaled_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type,
configs):
out = scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
configs)
res = (lhs, rhs)
return out, res
def mxfp8_dot_bwd(dimension_numbers, preferred_element_type, res, g):
def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g):
(lhs, rhs) = res
args = [g, lhs, rhs]
@@ -576,12 +664,20 @@ def mxfp8_dot_bwd(dimension_numbers, preferred_element_type, res, g):
"dimension_numbers": dimension_numbers,
"preferred_element_type": preferred_element_type,
}
grad_lhs = mxfp8_dot_general_transpose_lhs(*args, **kw_args)
grad_rhs = mxfp8_dot_general_transpose_rhs(*args, **kw_args)
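# configs is ordered [lhs, rhs, grad]: the lhs gradient contracts the incoming
# gradient g (configs[2]) with rhs (configs[1]), while the rhs gradient
# contracts g with lhs (configs[0]), hence the re-pairing below.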
lhs_kw_args = {
**kw_args,
"configs": [configs[2], configs[1]]
}
rhs_kw_args = {
**kw_args,
"configs": [configs[2], configs[0]]
}
grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args)
grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args)
return (grad_lhs, grad_rhs)
mxfp8_dot_general_fn.defvjp(mxfp8_dot_fwd, mxfp8_dot_bwd)
scaled_dot_general_fn.defvjp(scaled_dot_fwd, scaled_dot_bwd)
def ensure_tuple(dimension_numbers):
@@ -610,9 +706,10 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers):
return lhs_batched, rhs_batched, dn_batched
# TODO(shuw): mxfp8_dot_general should be in nn.function when upstreaming.
def mxfp8_dot_general(
lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32
def scaled_dot_general_wrapper(
lhs, rhs, dimension_numbers,
preferred_element_type=jnp.float32,
configs: BlockScaleConfigs=mxfp8_configs,
):
if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16):
msg = ('Only support preferred_element_type in (f32, bf16, f16), but got '
@@ -622,8 +719,8 @@ def mxfp8_dot_general(
lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim(
lhs, rhs, dimension_numbers
)
out = mxfp8_dot_general_fn(
lhs_batched, rhs_batched, dn_batched, preferred_element_type
out = scaled_dot_general_fn(
lhs_batched, rhs_batched, dn_batched, preferred_element_type, configs,
)
# Expanding batch dims for operands adds a singleton batch dim at axis 0 in
@@ -631,4 +728,3 @@
if dn_batched != dimension_numbers:
return jnp.squeeze(out, axis=0)
return out

View File

@@ -36,10 +36,14 @@ from jax._src.core import AxisName
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper as cudnn_scaled_matmul,
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
BlockScaleConfigs, mxfp8_configs)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike, DType
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.ops.special import logsumexp as _logsumexp
@@ -1159,3 +1163,92 @@ def dot_product_attention(
raise ValueError(f"Unsupported implementation option: {implementation}")
return jnp.reshape(out, output_shape)
def scaled_matmul(
lhs: Array,
rhs: Array,
lhs_scales: Array,
rhs_scales: Array,
preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
r"""
Performs scaled matrix multiplication between two 3D arrays, with scaling
factors applied to the matrices.
.. math::
\mathrm{ScaledMatmul}(lhs, rhs, lhs\_scales, rhs\_scales) = lhs\_scales \cdot rhs\_scales \cdot \mathrm{dot}(lhs, rhs)
Args:
lhs (Array): A 3D array of shape (B, M, K).
rhs (Array): A 3D array of shape (B, N, K).
lhs_scales (Array): A 3D array of shape (B, M, K_block).
rhs_scales (Array): A 3D array of shape (B, N, K_block).
preferred_element_type (DTypeLike, optional): The preferred data type
for the computation. Defaults to `jnp.float32`.
Returns:
Array: A 3D array of shape (B, M, N) representing the scaled matrix
multiplication result.
Raises:
AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
match the number of columns in `rhs` (`rhs_K`).
Notes:
- The function ensures that the `preferred_element_type` is
canonicalized before passing it to the underlying computation.
- Scaling is applied to the matrices based on the `lhs_scales` and
`rhs_scales` arrays, enabling efficient computations in blocks.
"""
B, M, lhs_K = lhs.shape
_, N, rhs_K = rhs.shape
assert lhs_K == rhs_K
_, _, K_block = lhs_scales.shape
preferred_element_type = dtypes.canonicalize_dtype(
np.dtype(preferred_element_type)
)
out = cudnn_scaled_matmul(
lhs,
rhs,
lhs_scales,
rhs_scales,
preferred_element_type=preferred_element_type,
)
return out
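A short usage sketch for this entry point (illustrative only: it reuses the quantization helper added earlier in this commit and needs a GPU/cuDNN stack with block-scaled matmul support):

import jax.numpy as jnp
from jax import nn
from jax._src.cudnn.scaled_matmul_stablehlo import (
    generate_quantized_tensors, mxfp8_configs)

# a_q: (B, M, K) fp8 values, b_q: (B, N, K) fp8 values, scales: one per 32-wide block
_, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
    2, 128, 96, 160, mxfp8_configs, dtype=jnp.bfloat16)
out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
                       preferred_element_type=jnp.bfloat16)   # shape (2, 128, 160)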
def scaled_dot_general(
lhs, rhs,
dimension_numbers,
preferred_element_type=jnp.float32,
configs: BlockScaleConfigs=mxfp8_configs,
):
r"""Scaled dot general operation.
Computes the scaled dot general on lhs and rhs, with quantization specified by configs:
.. math::
\widehat{lhs}, s_a=\mathrm{quantize}(lhs)
\widehat{rhs}, s_b=\mathrm{quantize}(rhs)
\mathrm{ScaledDot}(lhs, rhs)=s_a s_b \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
Args:
lhs: Left-hand side input tensor.
rhs: Right-hand side input tensor.
dimension_numbers: A tuple specifying the contraction and batch dimensions
for the dot general operation. Must follow the format:
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
preferred_element_type: The preferred output data type. Supported types are
`jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
configs: A `BlockScaleConfigs` object specifying the scaling
configurations for the operation. Defaults to `mxfp8_configs`.
Returns:
The result of the scaled dot general operation.
"""
return cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)
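And a corresponding sketch for `scaled_dot_general`, including the gradient path through the custom VJP defined earlier in this commit (again illustrative; the shapes follow `testScaledDotGeneral` below):

import jax
import jax.numpy as jnp
from jax import nn

a = jnp.ones((2, 256, 96), jnp.bfloat16)
b = jnp.ones((2, 96, 160), jnp.bfloat16)
dn = (([2], [1]), ([0], [0]))        # contract over K, batch over axis 0

def loss(a, b):
    # mxfp8_configs is the default, so configs can be omitted here
    y = nn.scaled_dot_general(a, b, dn, preferred_element_type=jnp.bfloat16)
    return jnp.sum(y)

out, (da, db) = jax.value_and_grad(loss, argnums=(0, 1))(a, b)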

View File

@@ -37,6 +37,8 @@ from jax._src.nn.functions import (
relu as relu,
relu6 as relu6,
dot_product_attention as dot_product_attention,
scaled_dot_general as scaled_dot_general,
scaled_matmul as scaled_matmul,
selu as selu,
sigmoid as sigmoid,
soft_sign as soft_sign,

View File

@@ -29,6 +29,11 @@ from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.cudnn.scaled_matmul_stablehlo import (
generate_quantized_tensors,
mxfp8_configs,
quantize_dequantize,
)
from jax.test_util import check_grads
from jax import nn
from jax import random
@@ -54,6 +59,80 @@ _cudnn_dbias_error = 'cuDNN only supports bias gradient'
@jtu.with_config(jax_legacy_prng_key="allow",
jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
configs=[mxfp8_configs,],
)
def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs):
batch, rhs_non_contract = 4, 256
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=dtype,
)
out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
preferred_element_type=dtype)
out_ref = jnp.matmul(a.astype(jnp.float32),
jnp.transpose(b, (0, 2, 1)).astype(jnp.float32))
self.assertArraysAllClose(
out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
)
@parameterized.product(
is_training=[True, False],
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
configs=[mxfp8_configs,],
)
def testScaledDotGeneral(
self, is_training, output_type, configs,
):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
k1, k2 = jax.random.split(jax.random.key(0), 2)
a_shape = [2, 256, 96]
b_shape = [2, 96, 160]
dimension_numbers = (([2], [1]), ([0], [0]))
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
configs[1].data_type,
)
scaled_dot_general_fn = partial(
nn.scaled_dot_general, configs=configs
)
def fwd(a, b, is_ref=False):
fn = jax.lax.dot_general if is_ref else scaled_dot_general_fn
y = fn(a, b, dimension_numbers,
preferred_element_type=output_type)
return jnp.sum(y)
if is_training:
j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
j_train_ref = jax.jit(
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
)
out, (x_grad, w_grad) = j_train(a, b)
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
else:
j_inference = jax.jit(fwd)
j_inference_ref = jax.jit(partial(fwd, is_ref=True))
out = j_inference(a, b)
out_ref = j_inference_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
@parameterized.product(
dtype=[jnp.bfloat16, jnp.float16],
group_num=[1, 2, 4],

View File

@@ -11,10 +11,11 @@ from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul,
mxfp8_dot_general,
quantize,
shape_normalization,
scaled_matmul_wrapper,
scaled_dot_general_wrapper,
mxfp8_configs,
generate_quantized_tensors,
quantize_dequantize,
)
@@ -42,57 +43,6 @@ expected_hlos = [
]
sharding_configs = [[i, j] for i, j in zip(input_sharding_configs, expected_hlos)]
def quantize_to_fp8(x, q_dtype, compute_dtype, scale):
# Explicitly cast the max values to the compute dtype to avoid unnecessary
# casting to FP32 during the subsequent math operations."
assert q_dtype in (jnp.float8_e4m3fn, )
dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
scaled_x = x / jnp.broadcast_to(
jnp.asarray(scale, dtype=compute_dtype), x.shape
)
clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
return clipped_x.astype(q_dtype)
def quantize_dequantize_fp8(x, q_dtype, scale, compute_dtype):
qx = quantize_to_fp8(x, q_dtype, compute_dtype, scale)
out = qx.astype(x.dtype) * jnp.broadcast_to(
jnp.asarray(scale, dtype=x.dtype), qx.shape
)
return out
def generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract, dtype=jnp.float32
):
cast_to_representable = partial(
quantize_dequantize_fp8,
scale=jnp.ones((1,)),
compute_dtype=dtype,
)
k1, k2 = jax.random.split(jax.random.key(123), 2)
f8_dtype = jnp.float8_e4m3fn
a = cast_to_representable(
jax.random.uniform(
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
),
f8_dtype,
)
b = cast_to_representable(
jax.random.uniform(
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
),
f8_dtype,
)
dn = ((2,), (0,))
a_3d = shape_normalization(a, dn)
b_3d = shape_normalization(b, dn)
a_q, a_scales = quantize(a, f8_dtype)
b_q, b_scales = quantize(b, f8_dtype)
return a, b, a_q, b_q, a_scales, b_scales
def shard_and_device_put(
mesh, a_sharding, b_sharding, a, b, a_scales=None, b_scales=None
@@ -126,19 +76,19 @@ def shard_and_device_put(
return a, b, in_shardings
def get_hlo_text(in_shardings):
def get_hlo_text(in_shardings, block_scale_configs=mxfp8_configs):
mesh_names = ("dp", "tp")
devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
mesh = Mesh(devices, mesh_names)
_, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
2, 512, 1024, 512
2, 512, 1024, 512, block_scale_configs,
)
with mesh:
a_q, b_q, a_scales, b_scales, in_shardings = shard_and_device_put(
mesh, in_shardings[0], in_shardings[1], a_q, b_q, a_scales, b_scales
)
pjit_fn = jax.jit(scaled_matmul, in_shardings=in_shardings)
pjit_fn = jax.jit(scaled_matmul_wrapper, in_shardings=in_shardings)
hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
return hlo.as_text()
@@ -180,16 +130,20 @@ class ScaledMatmulTest(jtu.JaxTestCase):
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_scaled_matmul(self, contract, lhs_non_contract, dtype):
def test_scaled_matmul(
self, contract, lhs_non_contract, dtype, block_scale_configs,
):
batch, rhs_non_contract = 2, 128
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract, dtype=dtype
batch, lhs_non_contract, contract, rhs_non_contract,
block_scale_configs, dtype=dtype,
)
def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
return scaled_matmul(
return scaled_matmul_wrapper(
lhs,
rhs,
lhs_scales,
@@ -216,15 +170,18 @@ class ScaledMatmulTest(jtu.JaxTestCase):
out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
)
@jtu.sample_product(sharding_config=sharding_configs)
@jtu.sample_product(
sharding_config=sharding_configs,
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_scaled_matmul_sharded(self, sharding_config):
def test_scaled_matmul_sharded(self, sharding_config, block_scale_configs):
if len(jax.local_devices()) < 4:
self.skipTest("Require at least 4 devices to run sharding tests.")
batch, contract, non_contract = 2, 1024, 256
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
batch, non_contract, contract, non_contract
batch, non_contract, contract, non_contract, block_scale_configs,
)
devices = np.array(jax.local_devices()[:4])
@@ -246,7 +203,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
args = [a_q, b_q, a_scales, b_scales]
j_scaled_matmul = jax.jit(
scaled_matmul, in_shardings=input_shardings
scaled_matmul_wrapper, in_shardings=input_shardings
)
hlo_text = j_scaled_matmul.lower(*args).compile().as_text()
hlo_pattern = re.compile(
@@ -268,7 +225,9 @@ class ScaledMatmulTest(jtu.JaxTestCase):
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class Mxfp8DotGeneralTest(jtu.JaxTestCase):
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
block_scale_configs = mxfp8_configs
def setUp(self):
super().setUp()
@@ -302,7 +261,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_dot_general(self, configs, output_type):
cast_to_representable = partial(
quantize_dequantize_fp8,
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
@@ -311,16 +270,21 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
a_shape, b_shape, dimension_numbers, is_training = configs
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
jnp.float8_e4m3fn,
self.block_scale_configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
jnp.float8_e4m3fn,
self.block_scale_configs[1].data_type,
)
scaled_dot_general = partial(
scaled_dot_general_wrapper,
configs=self.block_scale_configs
)
def fwd(a, b, is_ref=False):
fn = jax.lax.dot_general if is_ref else mxfp8_dot_general
y = fn(a, b, dimension_numbers, preferred_element_type=output_type)
fn = jax.lax.dot_general if is_ref else scaled_dot_general
y = fn(a, b, dimension_numbers,
preferred_element_type=output_type)
return jnp.sum(y)
if is_training:
@@ -349,7 +313,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
self.skipTest("Require at least 4 devices to run sharding tests.")
cast_to_representable = partial(
quantize_dequantize_fp8,
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
@@ -360,14 +324,20 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
k1, k2 = jax.random.split(jax.random.key(0), 2)
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0), jnp.float8_e4m3fn
jax.random.uniform(k1, a_shape, minval=-1.0),
self.block_scale_configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0), jnp.float8_e4m3fn
jax.random.uniform(k2, b_shape, minval=-1.0),
self.block_scale_configs[1].data_type,
)
scaled_dot_general = partial(
scaled_dot_general_wrapper,
configs=self.block_scale_configs
)
def fwd(a, b, is_ref=False):
fn = jax.lax.dot_general if is_ref else mxfp8_dot_general
fn = jax.lax.dot_general if is_ref else scaled_dot_general
y = fn(a, b, dimension_numbers)
# Use a little complex loss function to avoid constant grads, whose
# sharding info might be optimized off and then cause issue with the
@@ -415,7 +385,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
@jtu.run_on_devices("cuda")
def test_dot_general_vmap(self, configs):
cast_to_representable = partial(
quantize_dequantize_fp8,
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
@@ -426,15 +396,21 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
dimension_numbers = (([1], [1]), ([], []))
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0), jnp.float8_e4m3fn
jax.random.uniform(k1, a_shape, minval=-1.0),
self.block_scale_configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0), jnp.float8_e4m3fn
jax.random.uniform(k2, b_shape, minval=-1.0),
self.block_scale_configs[1].data_type,
)
scaled_dot_general = partial(
scaled_dot_general_wrapper,
configs=self.block_scale_configs
)
def fwd(a, b, is_ref=False):
fn = jax.vmap(
jax.lax.dot_general if is_ref else mxfp8_dot_general,
jax.lax.dot_general if is_ref else scaled_dot_general,
in_axes=(a_axis, b_axis, None),
out_axes=o_axis,
)
@@ -455,5 +431,3 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase):
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())