From 332af5876504ac0a9bab28e9f1a8cc3b5e5f6a0d Mon Sep 17 00:00:00 2001
From: shuw
Date: Thu, 13 Feb 2025 04:20:30 +0000
Subject: [PATCH] block_scale_config

---
 jax/_src/cudnn/scaled_matmul_stablehlo.py | 180 +++++++++++++++++-----
 jax/_src/nn/functions.py                  |  95 +++++++++++-
 jax/nn/__init__.py                        |   2 +
 tests/nn_test.py                          |  79 ++++++++++
 tests/scaled_matmul_stablehlo_test.py     | 136 +++++++---------
 5 files changed, 368 insertions(+), 124 deletions(-)

diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py
index d7b04de58..269cdbdc7 100644
--- a/jax/_src/cudnn/scaled_matmul_stablehlo.py
+++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py
@@ -1,6 +1,8 @@
+from dataclasses import dataclass
 import json
 import operator
 from functools import partial, reduce
+from typing import List
 
 # Third-party imports
 import jax
@@ -21,6 +23,26 @@ from jax.sharding import PartitionSpec as P
 Array = jnp.ndarray
 nv_scaled_dot_name = "__nv$scaled_dot"
 
+@dataclass
+class BlockScaleConfig:
+  mode: str
+  block_size: int
+  data_type: DTypeLike
+  scale_type: DTypeLike
+  global_scale: Array | None
+  infer_only: bool
+
+mxfp8_config = BlockScaleConfig(
+    mode='mxfp8',
+    block_size=32,
+    data_type=jnp.float8_e4m3fn,
+    scale_type=jnp.float8_e8m0fnu,
+    global_scale=None,
+    infer_only=False
+)
+
+BlockScaleConfigs = List[BlockScaleConfig]
+mxfp8_configs: BlockScaleConfigs = [mxfp8_config, mxfp8_config, mxfp8_config]
 
 def default_layouts(*shapes):
   return [range(len(shape) - 1, -1, -1) for shape in shapes]
@@ -301,7 +323,7 @@ def _scaled_matmul(
     )
     return output[0]
 
-def scaled_matmul(
+def scaled_matmul_wrapper(
    lhs: Array,
    rhs: Array,
    lhs_scales: Array,
@@ -447,39 +469,92 @@ def e8m0_to_dtype(x, dtype):
   )
   return new_x.astype(dtype)
 
-
-def quantize_core(x, q_dtype, scale):
-  scaled_x = x / scale
-  MAX = jnp.finfo(q_dtype).max.astype(x.dtype)
-  clipped_x = jnp.clip(scaled_x, -MAX, MAX)
-  return clipped_x.astype(q_dtype)
-
-
-def quantize(x, quantize_type, block_size=32):
+def quantize(x, config):
   x_shape = x.shape
   contract_dim = x_shape[-1]
+  block_size = config.block_size
   assert contract_dim >= block_size and contract_dim % block_size == 0
   x_new_shape = x_shape[:-1] + (x_shape[-1] // block_size, block_size)
   x = x.reshape(x_new_shape)  # shape = (B, M, K / block_size, block_size)
 
   amax = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
-  MAX = jnp.finfo(quantize_type).max.astype(x.dtype)
+  MAX = jnp.finfo(config.data_type).max.astype(x.dtype)
   scales = amax / MAX  # shape = (B, M, K / block_size, 1)
 
-  scales_q = cast_to_e8m0_with_rounding_up(scales)
-  scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
+  if config.mode == "mxfp8":
+    assert config.scale_type == jnp.float8_e8m0fnu
+    scales_q = cast_to_e8m0_with_rounding_up(scales)
+    scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
+  elif config.mode == "nvfp4":
+    assert config.scale_type == jnp.float8_e4m3fn
+    # TODO(shuw): placeholder; nvfp4 scale quantization is not implemented yet.
+    scales_q = scales
+    scaled_x = x
 
   clipped_x = jnp.clip(scaled_x, -MAX, MAX)
-  x_q = clipped_x.astype(quantize_type)
+  x_q = clipped_x.astype(config.data_type)
 
   x_q = x_q.reshape(x_shape)  # shape = (B, M, K)
   scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view(
-      jnp.float8_e8m0fnu
+      config.scale_type
   )
   return x_q, scales_q
 
 
-def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type):
+def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
+  # Explicitly cast the max values to the compute dtype to avoid unnecessary
+  # casting to FP32 during the subsequent math operations.
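+  # This helper applies a single broadcasted `scale` to the whole tensor (the
+  # block-wise scaling path lives in `quantize` above); `quantize_dequantize`
+  # below uses it to build inputs that are exactly representable in `q_dtype`.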
+ assert q_dtype in (jnp.float8_e4m3fn, ) + dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype) + scaled_x = x / jnp.broadcast_to( + jnp.asarray(scale, dtype=compute_dtype), x.shape + ) + clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max) + return clipped_x.astype(q_dtype) + +def quantize_dequantize(x, q_dtype, scale, compute_dtype): + qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale) + out = qx.astype(x.dtype) * jnp.broadcast_to( + jnp.asarray(scale, dtype=x.dtype), qx.shape + ) + return out + +def generate_quantized_tensors( + batch, lhs_non_contract, contract, rhs_non_contract, + configs, dtype=jnp.float32, + ): + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=dtype, + ) + + k1, k2 = jax.random.split(jax.random.key(123), 2) + + a = cast_to_representable( + jax.random.uniform( + k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype + ), + configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform( + k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype + ), + configs[1].data_type, + ) + + dn = ((2,), (0,)) + a_3d = shape_normalization(a, dn) + b_3d = shape_normalization(b, dn) + a_q, a_scales = quantize(a, configs[0]) + b_q, b_scales = quantize(b, configs[1]) + + return a, b, a_q, b_q, a_scales, b_scales + + +def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, + configs): if preferred_element_type is None: preferred_element_type = dtypes.result_type( lhs, rhs, return_weak_type_flag=False @@ -495,10 +570,11 @@ def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type): lhs_3d = shape_normalization(lhs, lhs_dn) rhs_3d = shape_normalization(rhs, rhs_dn) - lhs_q, lhs_scales = quantize(lhs_3d, jnp.float8_e4m3fn) - rhs_q, rhs_scales = quantize(rhs_3d, jnp.float8_e4m3fn) + lhs_config, rhs_config = configs[0], configs[1] + lhs_q, lhs_scales = quantize(lhs_3d, lhs_config) + rhs_q, rhs_scales = quantize(rhs_3d, rhs_config) - out = scaled_matmul( + out = scaled_matmul_wrapper( lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type ) @@ -509,8 +585,9 @@ def mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type): return expanded_out -def mxfp8_dot_general_transpose_lhs( - g, x, y, *, dimension_numbers, preferred_element_type, swap_ans=False +def scaled_dot_general_transpose_lhs( + g, x, y, *, dimension_numbers, preferred_element_type, configs, + swap_ans=False ): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim @@ -530,10 +607,15 @@ def mxfp8_dot_general_transpose_lhs( y_3d = shape_normalization(y, y_dn) g_3d = shape_normalization(g, g_dn) - g_q, g_scales = quantize(g_3d, jnp.float8_e4m3fn) - y_q, y_scales = quantize(y_3d, jnp.float8_e4m3fn) - out = scaled_matmul(g_q, y_q, g_scales, y_scales, preferred_element_type) + g_config, y_config = configs[0], configs[1] + + g_q, g_scales = quantize(g_3d, g_config) + y_q, y_scales = quantize(y_3d, y_config) + + out = scaled_matmul_wrapper( + g_q, y_q, g_scales, y_scales, preferred_element_type + ) expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn) expanded_out = jnp.reshape(out, expanded_out_shape) @@ -541,34 +623,40 @@ def mxfp8_dot_general_transpose_lhs( return x_bar -def mxfp8_dot_general_transpose_rhs( - g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike | None +def scaled_dot_general_transpose_rhs( + g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike, + configs: BlockScaleConfigs ): (x_contract, y_contract), 
(x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) - y_bar = mxfp8_dot_general_transpose_lhs( + y_bar = scaled_dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, preferred_element_type=preferred_element_type, + configs=configs, swap_ans=True, ) return y_bar -@partial(custom_vjp, nondiff_argnums=(2, 3)) -def mxfp8_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type): - return mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type) +@partial(custom_vjp, nondiff_argnums=(2, 3, 4)) +def scaled_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type, + configs): + return scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, + configs) -def mxfp8_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type): - out = mxfp8_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type) +def scaled_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type, + configs): + out = scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type, + configs) res = (lhs, rhs) return out, res -def mxfp8_dot_bwd(dimension_numbers, preferred_element_type, res, g): +def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): (lhs, rhs) = res args = [g, lhs, rhs] @@ -576,12 +664,20 @@ def mxfp8_dot_bwd(dimension_numbers, preferred_element_type, res, g): "dimension_numbers": dimension_numbers, "preferred_element_type": preferred_element_type, } - grad_lhs = mxfp8_dot_general_transpose_lhs(*args, **kw_args) - grad_rhs = mxfp8_dot_general_transpose_rhs(*args, **kw_args) + lhs_kw_args = { + **kw_args, + "configs": [configs[2], configs[1]] + } + rhs_kw_args = { + **kw_args, + "configs": [configs[2], configs[0]] + } + grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) + grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) return (grad_lhs, grad_rhs) -mxfp8_dot_general_fn.defvjp(mxfp8_dot_fwd, mxfp8_dot_bwd) +scaled_dot_general_fn.defvjp(scaled_dot_fwd, scaled_dot_bwd) def ensure_tuple(dimension_numbers): @@ -610,9 +706,10 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers): return lhs_batched, rhs_batched, dn_batched -# TODO(shuw): mxfp8_dot_general should be in nn.function when upstreaming. 
-def mxfp8_dot_general(
-    lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32
+def scaled_dot_general_wrapper(
+    lhs, rhs, dimension_numbers,
+    preferred_element_type=jnp.float32,
+    configs: BlockScaleConfigs = mxfp8_configs,
 ):
   if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16):
     msg = ('Only support preferred_element_type in (f32, bf16, f16), but got '
@@ -622,8 +719,8 @@ def mxfp8_dot_general(
   lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim(
       lhs, rhs, dimension_numbers
   )
-  out = mxfp8_dot_general_fn(
-      lhs_batched, rhs_batched, dn_batched, preferred_element_type
+  out = scaled_dot_general_fn(
+      lhs_batched, rhs_batched, dn_batched, preferred_element_type, configs,
   )
 
   # Expanding batch dims for operands adds a singleton batch dim at axis 0 in
@@ -631,4 +728,3 @@ def mxfp8_dot_general(
   if dn_batched != dimension_numbers:
     return jnp.squeeze(out, axis=0)
   return out
-
diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index 72ac74c38..a4acf60c2 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -36,10 +36,14 @@ from jax._src.core import AxisName
 from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
 from jax._src.cudnn.fused_attention_stablehlo import (
     dot_product_attention as cudnn_dot_product_attention, MaskType)
+from jax._src.cudnn.scaled_matmul_stablehlo import (
+    scaled_matmul_wrapper as cudnn_scaled_matmul,
+    scaled_dot_general_wrapper as cudnn_scaled_dot_general,
+    BlockScaleConfigs, mxfp8_configs)
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.numpy import util as numpy_util
-from jax._src.typing import Array, ArrayLike, DType
+from jax._src.typing import Array, ArrayLike, DType, DTypeLike
 from jax._src.ops.special import logsumexp as _logsumexp
 
 
@@ -1159,3 +1163,92 @@ def dot_product_attention(
     raise ValueError(f"Unsupported implementation option: {implementation}")
 
   return jnp.reshape(out, output_shape)
+
+def scaled_matmul(
+    lhs: Array,
+    rhs: Array,
+    lhs_scales: Array,
+    rhs_scales: Array,
+    preferred_element_type: DTypeLike = jnp.float32,
+) -> Array:
+  r"""
+  Performs scaled matrix multiplication between two 3D arrays, with scaling
+  factors applied to the matrices.
+
+  .. math::
+    \mathrm{ScaledMatmul}(lhs, rhs, lhs\_scales, rhs\_scales) = lhs\_scales \cdot rhs\_scales \cdot \mathrm{dot}(lhs, rhs)
+
+  Args:
+    lhs (Array): A 3D array of shape (B, M, K).
+    rhs (Array): A 3D array of shape (B, N, K).
+    lhs_scales (Array): A 3D array of shape (B, M, K_block), where
+      `K_block = K // block_size`.
+    rhs_scales (Array): A 3D array of shape (B, N, K_block).
+    preferred_element_type (DTypeLike, optional): The preferred data type
+      for the computation. Defaults to `jnp.float32`.
+
+  Returns:
+    Array: A 3D array of shape (B, M, N) representing the scaled matrix
+    multiplication result.
+
+  Raises:
+    AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
+      match the number of columns in `rhs` (`rhs_K`).
+
+  Notes:
+    - The function ensures that the `preferred_element_type` is
+      canonicalized before passing it to the underlying computation.
+    - Scaling is applied to the matrices based on the `lhs_scales` and
+      `rhs_scales` arrays, enabling efficient computations in blocks.
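+
+  Examples:
+
+    A minimal sketch of the expected shapes and dtypes, assuming the default
+    mxfp8 configuration (block size 32, ``float8_e4m3fn`` data with
+    ``float8_e8m0fnu`` scales). Illustrative only; it requires a CUDA/cuDNN
+    build with block-scaled matmul support::
+
+      B, M, N, K = 2, 128, 128, 256
+      lhs = jnp.zeros((B, M, K), dtype=jnp.float8_e4m3fn)
+      rhs = jnp.zeros((B, N, K), dtype=jnp.float8_e4m3fn)
+      lhs_scales = jnp.ones((B, M, K // 32), dtype=jnp.float8_e8m0fnu)
+      rhs_scales = jnp.ones((B, N, K // 32), dtype=jnp.float8_e8m0fnu)
+      out = jax.nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales)  # (B, M, N)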
+
+  """
+  B, M, lhs_K = lhs.shape
+  _, N, rhs_K = rhs.shape
+  assert lhs_K == rhs_K
+  _, _, K_block = lhs_scales.shape
+
+  preferred_element_type = dtypes.canonicalize_dtype(
+      np.dtype(preferred_element_type)
+  )
+  out = cudnn_scaled_matmul(
+      lhs,
+      rhs,
+      lhs_scales,
+      rhs_scales,
+      preferred_element_type=preferred_element_type,
+  )
+  return out
+
+def scaled_dot_general(
+    lhs, rhs,
+    dimension_numbers,
+    preferred_element_type=jnp.float32,
+    configs: BlockScaleConfigs = mxfp8_configs,
+  ):
+  r"""Scaled dot general operation.
+
+  Computes the scaled dot general on lhs and rhs, with quantization specified
+  by `configs`:
+
+  .. math::
+    \widehat{lhs}, s_a = \mathrm{quantize}(lhs)
+
+    \widehat{rhs}, s_b = \mathrm{quantize}(rhs)
+
+    \mathrm{ScaledDot}(lhs, rhs) = s_a s_b \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
+
+  Args:
+    lhs: Left-hand side input tensor.
+    rhs: Right-hand side input tensor.
+    dimension_numbers: A tuple specifying the contraction and batch dimensions
+      for the dot general operation. Must follow the format:
+      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
+    preferred_element_type: The preferred output data type. Supported types are
+      `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
+    configs: A `BlockScaleConfigs` object (one `BlockScaleConfig` each for the
+      lhs, rhs, and gradient operands) specifying the scaling configurations
+      for the operation. Defaults to `mxfp8_configs`.
+
+  Returns:
+    The result of the scaled dot general operation.
+  """
+  return cudnn_scaled_dot_general(
+      lhs, rhs, dimension_numbers,
+      preferred_element_type=preferred_element_type,
+      configs=configs
+  )
diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py
index ebe725c44..3f08e1c0f 100644
--- a/jax/nn/__init__.py
+++ b/jax/nn/__init__.py
@@ -37,6 +37,8 @@ from jax._src.nn.functions import (
   relu as relu,
   relu6 as relu6,
   dot_product_attention as dot_product_attention,
+  scaled_dot_general as scaled_dot_general,
+  scaled_matmul as scaled_matmul,
   selu as selu,
   sigmoid as sigmoid,
   soft_sign as soft_sign,
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 1f032b3f0..1525a45cd 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -29,6 +29,11 @@ from jax._src import config
 from jax._src import core
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
+from jax._src.cudnn.scaled_matmul_stablehlo import (
+    generate_quantized_tensors,
+    mxfp8_configs,
+    quantize_dequantize,
+)
 from jax.test_util import check_grads
 from jax import nn
 from jax import random
@@ -54,6 +59,80 @@ _cudnn_dbias_error = 'cuDNN only supports bias gradient'
 
 @jtu.with_config(jax_legacy_prng_key="allow",
                  jax_numpy_dtype_promotion="standard")
 class NNFunctionsTest(jtu.JaxTestCase):
+  @parameterized.product(
+      contract=[160, 96],
+      lhs_non_contract=[240, 100],
+      dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
+      configs=[mxfp8_configs,],
+  )
+  def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs):
+    batch, rhs_non_contract = 4, 256
+    a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
+        batch, lhs_non_contract, contract, rhs_non_contract,
+        configs, dtype=dtype,
+    )
+    out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
+                           preferred_element_type=dtype)
+    out_ref = jnp.matmul(a.astype(jnp.float32),
+                         jnp.transpose(b, (0, 2, 1)).astype(jnp.float32))
+    self.assertArraysAllClose(
+        out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
+    )
+
+  @parameterized.product(
+      is_training=[True, False],
+      output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
+      configs=[mxfp8_configs,],
+  )
+  def testScaledDotGeneral(
self, is_training, output_type, configs, + ): + cast_to_representable = partial( + quantize_dequantize, + scale=jnp.ones((1,)), + compute_dtype=jnp.float32, + ) + k1, k2 = jax.random.split(jax.random.key(0), 2) + a_shape = [2, 256, 96] + b_shape = [2, 96, 160] + dimension_numbers = (([2], [1]), ([0], [0])) + a = cast_to_representable( + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type), + configs[0].data_type, + ) + b = cast_to_representable( + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type), + configs[1].data_type, + ) + + scaled_dot_general_fn = partial( + nn.scaled_dot_general, configs=configs + ) + def fwd(a, b, is_ref=False): + fn = jax.lax.dot_general if is_ref else scaled_dot_general_fn + y = fn(a, b, dimension_numbers, + preferred_element_type=output_type) + return jnp.sum(y) + + if is_training: + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + + j_train_ref = jax.jit( + jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]) + ) + out, (x_grad, w_grad) = j_train(a, b) + out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b) + + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) + self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) + else: + j_inference = jax.jit(fwd) + j_inference_ref = jax.jit(partial(fwd, is_ref=True)) + out = j_inference(a, b) + out_ref = j_inference_ref(a, b) + self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + @parameterized.product( dtype=[jnp.bfloat16, jnp.float16], group_num=[1, 2, 4], diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 50e6af06b..ebfa4fbe7 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -11,10 +11,11 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version from jax._src.cudnn.scaled_matmul_stablehlo import ( - scaled_matmul, - mxfp8_dot_general, - quantize, - shape_normalization, + scaled_matmul_wrapper, + scaled_dot_general_wrapper, + mxfp8_configs, + generate_quantized_tensors, + quantize_dequantize, ) @@ -42,57 +43,6 @@ expected_hlos = [ ] sharding_configs = [[i, j] for i, j in zip(input_sharding_configs, expected_hlos)] -def quantize_to_fp8(x, q_dtype, compute_dtype, scale): - # Explicitly cast the max values to the compute dtype to avoid unnecessary - # casting to FP32 during the subsequent math operations." 
- assert q_dtype in (jnp.float8_e4m3fn, ) - dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype) - scaled_x = x / jnp.broadcast_to( - jnp.asarray(scale, dtype=compute_dtype), x.shape - ) - clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max) - return clipped_x.astype(q_dtype) - -def quantize_dequantize_fp8(x, q_dtype, scale, compute_dtype): - qx = quantize_to_fp8(x, q_dtype, compute_dtype, scale) - out = qx.astype(x.dtype) * jnp.broadcast_to( - jnp.asarray(scale, dtype=x.dtype), qx.shape - ) - return out - -def generate_quantized_tensors( - batch, lhs_non_contract, contract, rhs_non_contract, dtype=jnp.float32 - ): - cast_to_representable = partial( - quantize_dequantize_fp8, - scale=jnp.ones((1,)), - compute_dtype=dtype, - ) - - k1, k2 = jax.random.split(jax.random.key(123), 2) - - f8_dtype = jnp.float8_e4m3fn - - a = cast_to_representable( - jax.random.uniform( - k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype - ), - f8_dtype, - ) - b = cast_to_representable( - jax.random.uniform( - k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype - ), - f8_dtype, - ) - - dn = ((2,), (0,)) - a_3d = shape_normalization(a, dn) - b_3d = shape_normalization(b, dn) - a_q, a_scales = quantize(a, f8_dtype) - b_q, b_scales = quantize(b, f8_dtype) - - return a, b, a_q, b_q, a_scales, b_scales def shard_and_device_put( mesh, a_sharding, b_sharding, a, b, a_scales=None, b_scales=None @@ -126,19 +76,19 @@ def shard_and_device_put( return a, b, in_shardings -def get_hlo_text(in_shardings): +def get_hlo_text(in_shardings, block_scale_configs=mxfp8_configs): mesh_names = ("dp", "tp") devices = np.array(jax.local_devices()[:4]).reshape((2, 2)) mesh = Mesh(devices, mesh_names) _, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors( - 2, 512, 1024, 512 + 2, 512, 1024, 512, block_scale_configs, ) with mesh: a_q, b_q, a_scales, b_scales, in_shardings = shard_and_device_put( mesh, in_shardings[0], in_shardings[1], a_q, b_q, a_scales, b_scales ) - pjit_fn = jax.jit(scaled_matmul, in_shardings=in_shardings) + pjit_fn = jax.jit(scaled_matmul_wrapper, in_shardings=in_shardings) hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile() return hlo.as_text() @@ -180,16 +130,20 @@ class ScaledMatmulTest(jtu.JaxTestCase): contract=[160, 96], lhs_non_contract=[240, 100], dtype=[jnp.float16, jnp.bfloat16, jnp.float32], + block_scale_configs=[mxfp8_configs,], ) @jtu.run_on_devices("cuda") - def test_scaled_matmul(self, contract, lhs_non_contract, dtype): + def test_scaled_matmul( + self, contract, lhs_non_contract, dtype, block_scale_configs, + ): batch, rhs_non_contract = 2, 128 a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors( - batch, lhs_non_contract, contract, rhs_non_contract, dtype=dtype + batch, lhs_non_contract, contract, rhs_non_contract, + block_scale_configs, dtype=dtype, ) def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type): - return scaled_matmul( + return scaled_matmul_wrapper( lhs, rhs, lhs_scales, @@ -216,15 +170,18 @@ class ScaledMatmulTest(jtu.JaxTestCase): out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3 ) - @jtu.sample_product(sharding_config=sharding_configs) + @jtu.sample_product( + sharding_config=sharding_configs, + block_scale_configs=[mxfp8_configs,], + ) @jtu.run_on_devices("cuda") - def test_scaled_matmul_sharded(self, sharding_config): + def test_scaled_matmul_sharded(self, sharding_config, block_scale_configs): if len(jax.local_devices()) < 4: self.skipTest("Require at least 4 devices to run sharding tests.") batch, contract, 
non_contract = 2, 1024, 256 a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors( - batch, non_contract, contract, non_contract + batch, non_contract, contract, non_contract, block_scale_configs, ) devices = np.array(jax.local_devices()[:4]) @@ -246,7 +203,7 @@ class ScaledMatmulTest(jtu.JaxTestCase): args = [a_q, b_q, a_scales, b_scales] j_scaled_matmul = jax.jit( - scaled_matmul, in_shardings=input_shardings + scaled_matmul_wrapper, in_shardings=input_shardings ) hlo_text = j_scaled_matmul.lower(*args).compile().as_text() hlo_pattern = re.compile( @@ -268,7 +225,9 @@ class ScaledMatmulTest(jtu.JaxTestCase): @jtu.with_config(jax_numpy_dtype_promotion="standard") -class Mxfp8DotGeneralTest(jtu.JaxTestCase): +class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase): + + block_scale_configs = mxfp8_configs def setUp(self): super().setUp() @@ -302,7 +261,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase): @jtu.run_on_devices("cuda") def test_dot_general(self, configs, output_type): cast_to_representable = partial( - quantize_dequantize_fp8, + quantize_dequantize, scale=jnp.ones((1,)), compute_dtype=jnp.float32, ) @@ -311,16 +270,21 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase): a_shape, b_shape, dimension_numbers, is_training = configs a = cast_to_representable( jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type), - jnp.float8_e4m3fn, + self.block_scale_configs[0].data_type, ) b = cast_to_representable( jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type), - jnp.float8_e4m3fn, + self.block_scale_configs[1].data_type, ) + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=self.block_scale_configs + ) def fwd(a, b, is_ref=False): - fn = jax.lax.dot_general if is_ref else mxfp8_dot_general - y = fn(a, b, dimension_numbers, preferred_element_type=output_type) + fn = jax.lax.dot_general if is_ref else scaled_dot_general + y = fn(a, b, dimension_numbers, + preferred_element_type=output_type) return jnp.sum(y) if is_training: @@ -349,7 +313,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase): self.skipTest("Require at least 4 devices to run sharding tests.") cast_to_representable = partial( - quantize_dequantize_fp8, + quantize_dequantize, scale=jnp.ones((1,)), compute_dtype=jnp.float32, ) @@ -360,14 +324,20 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase): k1, k2 = jax.random.split(jax.random.key(0), 2) a = cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), jnp.float8_e4m3fn + jax.random.uniform(k1, a_shape, minval=-1.0), + self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), jnp.float8_e4m3fn + jax.random.uniform(k2, b_shape, minval=-1.0), + self.block_scale_configs[1].data_type, ) + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=self.block_scale_configs + ) def fwd(a, b, is_ref=False): - fn = jax.lax.dot_general if is_ref else mxfp8_dot_general + fn = jax.lax.dot_general if is_ref else scaled_dot_general y = fn(a, b, dimension_numbers) # Use a little complex loss function to avoid constant grads, whose # sharding info might be optimized off and then cause issue with the @@ -415,7 +385,7 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase): @jtu.run_on_devices("cuda") def test_dot_general_vmap(self, configs): cast_to_representable = partial( - quantize_dequantize_fp8, + quantize_dequantize, scale=jnp.ones((1,)), compute_dtype=jnp.float32, ) @@ -426,15 +396,21 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase): dimension_numbers = (([1], [1]), ([], [])) a = 
cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), jnp.float8_e4m3fn + jax.random.uniform(k1, a_shape, minval=-1.0), + self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), jnp.float8_e4m3fn + jax.random.uniform(k2, b_shape, minval=-1.0), + self.block_scale_configs[1].data_type, ) + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=self.block_scale_configs + ) def fwd(a, b, is_ref=False): fn = jax.vmap( - jax.lax.dot_general if is_ref else mxfp8_dot_general, + jax.lax.dot_general if is_ref else scaled_dot_general, in_axes=(a_axis, b_axis, None), out_axes=o_axis, ) @@ -455,5 +431,3 @@ class Mxfp8DotGeneralTest(jtu.JaxTestCase): if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) - -
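
A minimal usage sketch of the new `jax.nn.scaled_dot_general` entry point,
mirroring the call pattern exercised in `tests/nn_test.py`; the shapes, dtype,
and PRNG seed below are illustrative, the default `mxfp8_configs` block
scaling is assumed, and a CUDA/cuDNN build with block-scaled matmul support is
required:

    import jax
    import jax.numpy as jnp

    k1, k2 = jax.random.split(jax.random.key(0))
    # The contracting dimension (96) must be a multiple of the mxfp8 block size (32).
    a = jax.random.uniform(k1, (2, 256, 96), minval=-1.0, dtype=jnp.bfloat16)
    b = jax.random.uniform(k2, (2, 96, 160), minval=-1.0, dtype=jnp.bfloat16)

    # Contract a's last dim with b's middle dim; batch over the leading dim.
    dimension_numbers = (([2], [1]), ([0], [0]))
    out = jax.nn.scaled_dot_general(
        a, b, dimension_numbers, preferred_element_type=jnp.bfloat16
    )  # shape (2, 256, 160)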