Merge pull request #26944 from wenscarl:wenscarl/nvfp4

PiperOrigin-RevId: 736620378
jax authors 2025-03-13 13:30:46 -07:00
commit 726f49cbca
2 changed files with 261 additions and 19 deletions


@@ -367,6 +367,7 @@ def scaled_matmul_wrapper(
  preferred_element_type = dtypes.canonicalize_dtype(
      np.dtype(preferred_element_type)
  )
  out = _scaled_matmul(
      lhs,
      rhs,
@@ -374,6 +375,7 @@ def scaled_matmul_wrapper(
      rhs_scales,
      preferred_element_type=preferred_element_type,
  )
  return out

def shape_normalization(x, dimension_numbers):
@@ -489,12 +491,14 @@ def quantize(x, config):
    scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
  elif config.mode == "nvfp4":
    assert config.scale_type == jnp.float8_e4m3fn
-    # shuw(TODO): Add when XLA is ready and e2m1fn is available.
-    scales_q = scales
-    scaled_x = x
+    assert config.global_scale.dtype == jnp.float32
+    scales = scales / config.global_scale
+    scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn))
+    scaled_x = x / (scales_q.astype(jnp.float32) *
+                    config.global_scale).astype(x.dtype)
  else:
    raise ValueError(f"Unrecognized mode: {config.mode}.")

  clipped_x = jnp.clip(scaled_x, -MAX, MAX)
  x_q = clipped_x.astype(config.data_type)
@@ -504,10 +508,6 @@ def quantize(x, config):
  )
  return x_q, scales_q

def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
                    configs):
  if preferred_element_type is None:
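
Read on its own, the nvfp4 branch added to quantize above is a two-level scheme: one scale per 16-element block, stored as float8_e4m3fn, factored against a tensor-wide float32 global scale. A minimal standalone sketch of that round trip, assuming ml_dtypes >= 0.5.0 for the narrow dtypes (the helper name is illustrative, not part of the diff):

import jax
import jax.numpy as jnp

def nvfp4_quantize_sketch(x, global_scale, block_size=16):
  # One scale per block of 16 along the last axis.
  blocks = x.reshape(*x.shape[:-1], -1, block_size)
  amax = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True).astype(jnp.float32)
  # Map each block's max onto the e2m1 max (6.0), then factor out the
  # global scale before storing the block scale in e4m3.
  scales = amax / jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)
  scales_q = (scales / global_scale).astype(jnp.float8_e4m3fn)
  # Divide by the dequantized scale (e4m3 block scale times global scale),
  # mirroring the scaled_x computation in the hunk above.
  scaled = blocks / (scales_q.astype(jnp.float32) * global_scale).astype(x.dtype)
  x_q = jnp.clip(scaled, -6.0, 6.0).astype(jnp.float4_e2m1fn)
  return x_q.reshape(x.shape), jnp.squeeze(scales_q, -1)

x = jax.random.uniform(jax.random.key(0), (4, 32), minval=-1.0)
x_q, s_q = nvfp4_quantize_sketch(x, jnp.float32(1.0))

The jax.lax.optimization_barrier in the diff presumably keeps XLA from folding the e4m3 round trip of the scales away, so the stored scales match what the division by the dequantized scale actually used.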
@@ -529,10 +529,18 @@ def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
  lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
  rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)

+  out_dtype = preferred_element_type
+  if configs[0].mode == 'nvfp4':
+    out_dtype = jnp.float32

  out = scaled_matmul_wrapper(
-      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type
+      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type=out_dtype
  )

+  if configs[0].mode == 'nvfp4':
+    out *= (configs[0].global_scale * configs[1].global_scale)
+    out = out.astype(preferred_element_type)

  expanded_out_shape = compute_dot_output_shape(
      lhs.shape, rhs.shape, lhs_dn, rhs_dn
  )
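
The output rescaling added here relies on the identity (gs_a * A') @ (gs_b * B')ᵀ = gs_a * gs_b * (A' @ B'ᵀ): the block-scaled matmul runs with float32 accumulation, and both tensor-wide global scales are folded into the product afterwards. A tiny sketch of the same folding (shapes and values illustrative):

import jax.numpy as jnp

gs_a, gs_b = jnp.float32(0.02), jnp.float32(0.05)
a = jnp.ones((2, 4, 8), jnp.float32)    # stands in for the quantized lhs, i.e. A / gs_a
b = jnp.ones((2, 16, 8), jnp.float32)   # stands in for the quantized rhs, i.e. B / gs_b
out = jnp.einsum("bmk,bnk->bmn", a, b)             # f32 accumulation
out = (out * (gs_a * gs_b)).astype(jnp.bfloat16)   # fold global scales, then cast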
@@ -564,13 +572,15 @@ def scaled_dot_general_transpose_lhs(
  g_3d = shape_normalization(g, g_dn)
  g_config, y_config = configs[0], configs[1]

+  if configs[0].mode != 'nvfp4':
-  g_q, g_scales = quantize(g_3d, g_config)
-  y_q, y_scales = quantize(y_3d, y_config)
+    g_q, g_scales = quantize(g_3d, g_config)
+    y_q, y_scales = quantize(y_3d, y_config)
-  out = scaled_matmul_wrapper(
-      g_q, y_q, g_scales, y_scales, preferred_element_type
-  )
+    out = scaled_matmul_wrapper(
+        g_q, y_q, g_scales, y_scales, preferred_element_type
+    )
+  else:
+    out = jnp.matmul(g_3d, jnp.permute_dims(y_3d, (0, 2, 1)), preferred_element_type=preferred_element_type)

  expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn)
  expanded_out = jnp.reshape(out, expanded_out_shape)
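
For the lhs transpose (gradient) rule, the nvfp4 path above skips re-quantization and computes the cotangent with a plain full-precision batched matmul against the transposed rhs. The same operation in isolation, with illustrative shapes:

import jax.numpy as jnp

g_3d = jnp.ones((2, 4, 8), jnp.float32)   # cotangent, [batch, m, n]
y_3d = jnp.ones((2, 16, 8), jnp.float32)  # normalized rhs, [batch, k, n]
grad_lhs = jnp.matmul(g_3d, jnp.permute_dims(y_3d, (0, 2, 1)),
                      preferred_element_type=jnp.float32)  # [batch, m, k]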


@@ -67,13 +67,15 @@ expected_output_spec = [
]
sharding_configs = {
    input_sharding: (hlo, output_spec)
-    for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
+    for input_sharding, hlo, output_spec in zip(input_shardings,
+                                                expected_hlos,
+                                                expected_output_spec)
}

def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
-  assert q_dtype in (jnp.float8_e4m3fn, )
+  assert q_dtype in (jnp.float8_e4m3fn, jnp.float4_e2m1fn)
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(
      jnp.asarray(scale, dtype=compute_dtype), x.shape
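
As a worked example of the saturating cast this helper performs for the newly allowed float4_e2m1fn: its finite maximum is 6.0, so scaled values land in [-6, 6] before the cast (a standalone sketch, not the test's code):

import jax.numpy as jnp

x = jnp.array([0.03, -0.8, 0.51], jnp.float32)
dtype_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)  # 6.0
scale = jnp.max(jnp.abs(x)) / dtype_max
x_q = jnp.clip(x / scale, -dtype_max, dtype_max).astype(jnp.float4_e2m1fn)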
@@ -153,6 +155,84 @@ def shard_and_device_put(
  return a, b, in_shardings

def create_nvfp4_configs(global_scale=None):
  if _dtypes.float4_e2m1fn is None:
    return None

  g_one_scale = jnp.ones((1, ), dtype=jnp.float32)
  nvfp4_config = BlockScaleConfig(
      mode='nvfp4',
      block_size=16,
      data_type=jnp.float4_e2m1fn,
      scale_type=jnp.float8_e4m3fn,
      global_scale=g_one_scale if global_scale is None else global_scale,
      infer_only=False
  )
  return [nvfp4_config for _ in range(3)]

def update_global_scale(config, new_global_scale):
  config.global_scale = new_global_scale
  return config
def generate_nvfp4_quantized_tensors(dot_config, output_type):
  k1, k2 = jax.random.split(jax.random.key(0), 2)

  a_shape, b_shape, dimension_numbers = dot_config
  (a_contract, b_contract), (a_batch, b_batch) = dimension_numbers
  a_dn = (a_contract, a_batch)
  b_dn = (b_contract, b_batch)

  a_raw = jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type)
  b_raw = jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type)
  a = shape_normalization(a_raw, a_dn)
  b = shape_normalization(b_raw, b_dn)

  # Initialize NVFP4 configurations
  block_scale_configs_nvfp4 = create_nvfp4_configs()

  # Compute maximum absolute values for scaling
  amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32)
  amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32)

  # Update global scales
  data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype(
      jnp.float32
  )
  scale_max = jnp.finfo(block_scale_configs_nvfp4[0].scale_type).max.astype(
      jnp.float32
  )

  block_scale_configs_nvfp4[0] = update_global_scale(
      block_scale_configs_nvfp4[0], amax_a / (data_max * scale_max))
  block_scale_configs_nvfp4[1] = update_global_scale(
      block_scale_configs_nvfp4[1], amax_b / (data_max * scale_max))

  # Quantize tensors
  a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
  b_nvfp4, b_scale = quantize(b, block_scale_configs_nvfp4[1])

  # Reshape and scale quantized tensors
  def reshape_and_scale(x, scale, global_scale, bs, k):
    reshaped = x.astype(output_type).reshape(*bs, k // 16, 16)
    scaled = reshaped * jnp.expand_dims(scale.astype(output_type), -1)
    return scaled.reshape(*bs, k) * global_scale.astype(output_type)

  *bs_a, k_a = a_nvfp4.shape
  *bs_b, k_b = b_nvfp4.shape
  assert k_a == k_b

  a_dequantized = reshape_and_scale(
      a_nvfp4, a_scale, block_scale_configs_nvfp4[0].global_scale, bs_a, k_a)
  b_dequantized = reshape_and_scale(
      b_nvfp4, b_scale, block_scale_configs_nvfp4[1].global_scale, bs_b, k_b)

  return (
      (a_raw, b_raw),
      (a_dequantized, b_dequantized),
      (a_nvfp4, b_nvfp4, a_scale, b_scale),
      block_scale_configs_nvfp4
  )
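
A hypothetical call to the helper above (imports as elsewhere in this test file): it returns the raw operands, a dequantized reference for tolerance checks, and the packed quantized tensors with their block scales and configs.

dot_config = ((2, 240, 160), (2, 128, 160), (([2], [2]), ([0], [0])))
(a_raw, b_raw), (a_dq, b_dq), (a_q, b_q, a_s, b_s), configs = (
    generate_nvfp4_quantized_tensors(dot_config, jnp.float32)
)
# a_q, b_q hold float4_e2m1fn data; a_s, b_s hold float8_e4m3fn block scales.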
def create_mxfp8_configs():
  if _dtypes.float8_e8m0fnu is None:
    return None
@@ -184,7 +264,6 @@ def get_hlo_text(in_shardings, block_scale_configs):
  hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
  return hlo.as_text()

@jtu.with_config(jax_numpy_dtype_promotion="standard")
class ScaledMatmulTest(jtu.JaxTestCase):
@@ -197,6 +276,8 @@ class ScaledMatmulTest(jtu.JaxTestCase):
      return
    if _dtypes.float8_e8m0fnu is None:
      self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
+    if _dtypes.float4_e2m1fn is None:
+      self.skipTest("Requires >= ml_dtypes 0.5.0 to support float4_e2m1fn")
    if cudnn_version < 90700:
      self.skipTest("Requires >= cuDNN 9.7.0")
    if not jtu.is_cuda_compute_capability_at_least("10.0"):
@@ -223,6 +304,57 @@ class ScaledMatmulTest(jtu.JaxTestCase):
        hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}"
    )

  @jtu.sample_product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul_nvfp4(
      self, contract, lhs_non_contract, dtype,
  ):
    batch, rhs_non_contract = 2, 128
    dot_config = (
        (batch, lhs_non_contract, contract),
        (batch, rhs_non_contract, contract),
        (([2], [2]), ([0], [0]))
    )
    _, (a_dq, b_dq), (a_q, b_q, a_s, b_s), block_scale_configs = (
        generate_nvfp4_quantized_tensors(dot_config, dtype)
    )
    a_gs = block_scale_configs[0].global_scale
    b_gs = block_scale_configs[1].global_scale

    def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
      out = scaled_matmul_wrapper(
          lhs,
          rhs,
          lhs_scales,
          rhs_scales,
          preferred_element_type=jnp.float32,
      )
      gs = a_gs * b_gs
      return (out * gs).astype(out_type)

    j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
    hlo_text = (
        j_scaled_matmul.lower(a_q, b_q, a_s, b_s)
        .compile()
        .as_text()
    )
    hlo_pattern = re.compile(
        r".*".join([re.escape(x) for x in ("custom-call", c_name)])
    )
    self.assertRegex(hlo_text, hlo_pattern)

    out = j_scaled_matmul(a_q, b_q, a_s, b_s)
    out_ref = jnp.einsum("BMK,BNK->BMN", a_dq, b_dq)
    self.assertArraysAllClose(
        out, out_ref.astype(dtype), rtol=1e-2, atol=5e-2
    )

  @jtu.sample_product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
@@ -326,7 +458,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
    )

@jtu.with_config(jax_numpy_dtype_promotion="standard")
-class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
+class ScaledDotGeneralTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
@@ -344,6 +476,106 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
  block_scale_configs = create_mxfp8_configs()

  @jtu.sample_product(
      shape=[
          (1, 128, 128),
          (64, 32),
          (1024, 2048),
      ],
  )
  @jtu.run_on_devices("cuda")
  def test_quantize_nvfp4(self, shape):
    # Test that the quantize-dequantize logic is valid with XLA
    output_type = jnp.float32
    k1, k2 = jax.random.split(jax.random.key(0), 2)
    a = jax.random.uniform(k1, shape, minval=-1.0, dtype=output_type)

    block_scale_configs_nvfp4 = create_nvfp4_configs()
    data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)
    scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
    amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) / (data_max * scale_max)
    block_scale_configs_nvfp4[0] = update_global_scale(
        block_scale_configs_nvfp4[0], jnp.asarray(amax_a, jnp.float32)
    )

    def fn(a):
      a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
      return a_nvfp4, a_scale

    out_q, scale = jax.jit(fn)(a)
    out_q_ref, scale_ref = fn(a)
    self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5)
    self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5)
  @jtu.sample_product(
      configs=[
          # a_shape, b_shape, dimension_numbers, is_training
          ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0])), False),
          ((30, 64), (100, 64), (([1], [1]), ([], [])), False),
          ((192, 96), (160, 96), (([1], [1]), ([], [])), True),
          ((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True),
          ((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True),
          (
              (1, 128, 128, 2),
              (128, 1, 2, 128),
              (([2], [0]), ([0, 3], [1, 2])),
              True,
          ),
      ],
      output_type=[jnp.float32, jnp.float16, jnp.bfloat16],
  )
  @jtu.run_on_devices("cuda")
  def test_dot_general_nvfp4(self, configs, output_type):
    (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = (
        generate_nvfp4_quantized_tensors(configs[:-1], output_type)
    )
    a_gs = block_scale_configs[0].global_scale
    b_gs = block_scale_configs[1].global_scale

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=block_scale_configs
    )
    dimension_numbers = configs[2]
    is_training = configs[-1]

    def fwd(a, b, is_ref=False, use_normalized=False):
      fn = jax.lax.dot_general if is_ref else scaled_dot_general
      if is_ref and use_normalized:
        dms = (([2], [2]), ([0], [0]))
      else:
        dms = dimension_numbers
      y = fn(a, b, dms, preferred_element_type=output_type)
      return jnp.sum(y) if is_training else y

    if is_training:
      j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
      j_train_ref = jax.jit(
          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
      )
      j_train_fwd_ref = jax.jit(
          jax.value_and_grad(
              partial(fwd, is_ref=True, use_normalized=True), argnums=[0, 1]
          )
      )
      out, (x_grad, w_grad) = j_train(a_raw, b_raw)
      _, (x_grad_ref, w_grad_ref) = j_train_ref(a_raw, b_raw)
      out_ref, _ = j_train_fwd_ref(a_dq, b_dq)

      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
    else:
      j_inference = jax.jit(fwd)
      j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True))
      out = j_inference(a_raw, b_raw)
      out_ref = jnp.reshape(j_inference_ref(a_dq, b_dq), out.shape)
      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=2e-1)
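
For orientation, scaled_dot_general_wrapper composes like jax.lax.dot_general plus a configs argument, which is how the test above drives it; a hedged usage sketch reusing the helpers from this file:

configs = create_nvfp4_configs()
out = scaled_dot_general_wrapper(
    jnp.ones((1, 128, 128), jnp.bfloat16),
    jnp.ones((1, 128, 128), jnp.bfloat16),
    (([2], [2]), ([0], [0])),
    preferred_element_type=jnp.bfloat16,
    configs=configs,
)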
  @jtu.sample_product(
      configs=[
          # a_shape, b_shape, dimension_numbers, is_training