diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py
index bffd0f64c..1a8dee293 100644
--- a/jax/_src/cudnn/scaled_matmul_stablehlo.py
+++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py
@@ -367,6 +367,7 @@ def scaled_matmul_wrapper(
   preferred_element_type = dtypes.canonicalize_dtype(
       np.dtype(preferred_element_type)
   )
+
   out = _scaled_matmul(
       lhs,
       rhs,
@@ -374,6 +375,7 @@
       rhs_scales,
       preferred_element_type=preferred_element_type,
   )
+
   return out

 def shape_normalization(x, dimension_numbers):
@@ -489,12 +491,14 @@ def quantize(x, config):
     scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
   elif config.mode == "nvfp4":
     assert config.scale_type == jnp.float8_e4m3fn
-    # shuw(TODO): Add when XLA is ready and e2m1fn is available.
-    scales_q = scales
-    scales_x = x
+    assert config.global_scale.dtype == jnp.float32
+
+    scales = scales / config.global_scale
+    scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn))
+    scaled_x = x / (scales_q.astype(jnp.float32) *
+                    config.global_scale).astype(x.dtype)
   else:
     raise ValueError(f"Unrecognized mode: {config.mode}.")
-
   clipped_x = jnp.clip(scaled_x, -MAX, MAX)
   x_q = clipped_x.astype(config.data_type)

@@ -504,10 +508,6 @@ def quantize(x, config):
   )
   return x_q, scales_q
-
-
-
-

 def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
                     configs):
   if preferred_element_type is None:
@@ -529,10 +529,18 @@ def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
   lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
   rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)

+  out_dtype = preferred_element_type
+  if configs[0].mode == 'nvfp4':
+    out_dtype = jnp.float32
+
   out = scaled_matmul_wrapper(
-      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type
+      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type=out_dtype
   )

+  if configs[0].mode == 'nvfp4':
+    out *= (configs[0].global_scale * configs[1].global_scale)
+    out = out.astype(preferred_element_type)
+
   expanded_out_shape = compute_dot_output_shape(
       lhs.shape, rhs.shape, lhs_dn, rhs_dn
   )
@@ -564,13 +572,15 @@ def scaled_dot_general_transpose_lhs(
   g_3d = shape_normalization(g, g_dn)
   g_config, y_config = configs[0], configs[1]

+  if configs[0].mode != 'nvfp4':
+    g_q, g_scales = quantize(g_3d, g_config)
+    y_q, y_scales = quantize(y_3d, y_config)
-  g_q, g_scales = quantize(g_3d, g_config)
-  y_q, y_scales = quantize(y_3d, y_config)
-
-  out = scaled_matmul_wrapper(
-      g_q, y_q, g_scales, y_scales, preferred_element_type
-  )
+    out = scaled_matmul_wrapper(
+        g_q, y_q, g_scales, y_scales, preferred_element_type
+    )
+  else:
+    out = jnp.matmul(g_3d, jnp.permute_dims(y_3d, (0, 2, 1)), preferred_element_type=preferred_element_type)

   expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn)
   expanded_out = jnp.reshape(out, expanded_out_shape)

diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py
index 0ad9bf94b..141839a19 100644
--- a/tests/scaled_matmul_stablehlo_test.py
+++ b/tests/scaled_matmul_stablehlo_test.py
@@ -67,13 +67,15 @@ expected_output_spec = [
 ]
 sharding_configs = {
     input_sharding: (hlo, output_spec)
-    for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
+    for input_sharding, hlo, output_spec in zip(input_shardings,
+                                                expected_hlos,
+                                                expected_output_spec)
 }

 def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
   # Explicitly cast the max values to the compute dtype to avoid unnecessary
   # casting to FP32 during the subsequent math operations."
-  assert q_dtype in (jnp.float8_e4m3fn, )
+  assert q_dtype in (jnp.float8_e4m3fn, jnp.float4_e2m1fn)
   dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
   scaled_x = x / jnp.broadcast_to(
       jnp.asarray(scale, dtype=compute_dtype), x.shape
@@ -153,6 +155,84 @@ def shard_and_device_put(

   return a, b, in_shardings

+def create_nvfp4_configs(global_scale=None):
+  if _dtypes.float4_e2m1fn is None:
+    return None
+  g_one_scale = jnp.ones((1, ), dtype=jnp.float32)
+  nvfp4_config = BlockScaleConfig(
+      mode='nvfp4',
+      block_size=16,
+      data_type=jnp.float4_e2m1fn,
+      scale_type=jnp.float8_e4m3fn,
+      global_scale=g_one_scale if global_scale is None else global_scale,
+      infer_only=False
+  )
+
+  return [nvfp4_config for _ in range(3)]
+
+def update_global_scale(config, new_global_scale):
+  config.global_scale = new_global_scale
+  return config
+
+def generate_nvfp4_quantized_tensors(dot_config, output_type):
+  k1, k2 = jax.random.split(jax.random.key(0), 2)
+
+  a_shape, b_shape, dimension_numbers = dot_config
+  (a_contract, b_contract), (a_batch, b_batch) = dimension_numbers
+  a_dn = (a_contract, a_batch)
+  b_dn = (b_contract, b_batch)
+
+  a_raw = jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type)
+  b_raw = jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type)
+  a = shape_normalization(a_raw, a_dn)
+  b = shape_normalization(b_raw, b_dn)
+
+  # Initialize NVFP4 configurations
+  block_scale_configs_nvfp4 = create_nvfp4_configs()
+
+  # Compute maximum absolute values for scaling
+  amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32)
+  amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32)
+
+  # Update global scales
+  data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype(
+      jnp.float32
+  )
+  scale_max = jnp.finfo(block_scale_configs_nvfp4[0].scale_type).max.astype(
+      jnp.float32
+  )
+
+  block_scale_configs_nvfp4[0] = update_global_scale(
+      block_scale_configs_nvfp4[0], amax_a / (data_max * scale_max))
+  block_scale_configs_nvfp4[1] = update_global_scale(
+      block_scale_configs_nvfp4[1], amax_b / (data_max * scale_max))
+
+  # Quantize tensors
+  a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
+  b_nvfp4, b_scale = quantize(b, block_scale_configs_nvfp4[1])
+
+  # Reshape and scale quantized tensors
+  def reshape_and_scale(x, scale, global_scale, bs, k):
+    reshaped = x.astype(output_type).reshape(*bs, k // 16, 16)
+    scaled = reshaped * jnp.expand_dims(scale.astype(output_type), -1)
+    return scaled.reshape(*bs, k) * global_scale.astype(output_type)
+
+  *bs_a, k_a = a_nvfp4.shape
+  *bs_b, k_b = b_nvfp4.shape
+  assert k_a == k_b
+
+  a_dequantized = reshape_and_scale(
+      a_nvfp4, a_scale, block_scale_configs_nvfp4[0].global_scale, bs_a, k_a)
+  b_dequantized = reshape_and_scale(
+      b_nvfp4, b_scale, block_scale_configs_nvfp4[1].global_scale, bs_b, k_b)
+
+  return (
+      (a_raw, b_raw),
+      (a_dequantized, b_dequantized),
+      (a_nvfp4, b_nvfp4, a_scale, b_scale),
+      block_scale_configs_nvfp4
+  )
+
 def create_mxfp8_configs():
   if _dtypes.float8_e8m0fnu is None:
     return None
@@ -184,7 +264,6 @@ def get_hlo_text(in_shardings, block_scale_configs):
   hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
   return hlo.as_text()

-
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class ScaledMatmulTest(jtu.JaxTestCase):

@@ -197,6 +276,8 @@ class ScaledMatmulTest(jtu.JaxTestCase):
       return
     if _dtypes.float8_e8m0fnu is None:
       self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
+    if _dtypes.float4_e2m1fn is None:
+      self.skipTest("Requires >= ml_dtypes 0.5.0 to support float4_e2m1fn")
     if cudnn_version < 90700:
       self.skipTest("Requires >= cuDNN 9.7.0")
     if not jtu.is_cuda_compute_capability_at_least("10.0"):
@@ -223,6 +304,57 @@ class ScaledMatmulTest(jtu.JaxTestCase):
         hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}"
     )

+  @jtu.sample_product(
+      contract=[160, 96],
+      lhs_non_contract=[240, 100],
+      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
+  )
+  @jtu.run_on_devices("cuda")
+  def test_scaled_matmul_nvfp4(
+      self, contract, lhs_non_contract, dtype,
+  ):
+    batch, rhs_non_contract = 2, 128
+    dot_config = (
+        (batch, lhs_non_contract, contract),
+        (batch, rhs_non_contract, contract),
+        (([2], [2]), ([0], [0]))
+    )
+    _, (a_dq, b_dq), (a_q, b_q, a_s, b_s), block_scale_configs = (
+        generate_nvfp4_quantized_tensors(dot_config, dtype)
+    )
+    a_gs = block_scale_configs[0].global_scale
+    b_gs = block_scale_configs[1].global_scale
+
+    def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
+      out = scaled_matmul_wrapper(
+          lhs,
+          rhs,
+          lhs_scales,
+          rhs_scales,
+          preferred_element_type=jnp.float32,
+      )
+      gs = a_gs * b_gs
+      return (out * gs).astype(out_type)
+
+    j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
+    hlo_text = (
+        j_scaled_matmul.lower(a_q, b_q, a_s, b_s)
+        .compile()
+        .as_text()
+    )
+    hlo_pattern = re.compile(
+        r".*".join([re.escape(x) for x in ("custom-call", c_name)])
+    )
+    self.assertRegex(hlo_text, hlo_pattern)
+
+    out = j_scaled_matmul(a_q, b_q, a_s, b_s)
+    out_ref = jnp.einsum(
+        "BMK,BNK->BMN", a_dq, b_dq
+    )
+    self.assertArraysAllClose(
+        out, out_ref.astype(dtype), rtol=1e-2, atol=5e-2
+    )
+
   @jtu.sample_product(
       contract=[160, 96],
       lhs_non_contract=[240, 100],
@@ -326,7 +458,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
     )

 @jtu.with_config(jax_numpy_dtype_promotion="standard")
-class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
+class ScaledDotGeneralTest(jtu.JaxTestCase):

   def setUp(self):
     super().setUp()
@@ -344,6 +476,106 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):

   block_scale_configs = create_mxfp8_configs()

+  @jtu.sample_product(
+      shape=[
+          (1, 128, 128),
+          (64, 32),
+          (1024, 2048),
+      ],
+  )
+  @jtu.run_on_devices("cuda")
+  def test_quantize_nvfp4(self, shape):
+    # Test that the q-dq logic is valid when compiled with XLA
+    output_type = jnp.float32
+    k1, k2 = jax.random.split(jax.random.key(0), 2)
+
+    a = jax.random.uniform(k1, shape, minval=-1.0, dtype=output_type)
+
+    block_scale_configs_nvfp4 = create_nvfp4_configs()
+    data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)
+    scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
+    amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) / (data_max * scale_max)
+    block_scale_configs_nvfp4[0] = update_global_scale(
+        block_scale_configs_nvfp4[0], jnp.asarray(amax_a, jnp.float32)
+    )
+
+    def fn(a):
+      a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
+      return a_nvfp4, a_scale
+
+    out_q, scale = jax.jit(fn)(a)
+    out_q_ref, scale_ref = fn(a)
+    self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5)
+    self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5)
+
+  @jtu.sample_product(
+      configs=[
+          # a_shape, b_shape, dimension_numbers, is_training
+          ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0])), False),
+          ((30, 64), (100, 64), (([1], [1]), ([], [])), False),
+          ((192, 96), (160, 96), (([1], [1]), ([], [])), True),
+          ((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True),
+          ((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True),
+          (
+              (1, 128, 128, 2),
+              (128, 1, 2, 128),
+              (([2], [0]), ([0, 3], [1, 2])),
+              True,
+          ),
+      ],
+      output_type=[jnp.float32, jnp.float16, jnp.bfloat16],
+  )
+  @jtu.run_on_devices("cuda")
+  def test_dot_general_nvfp4(self, configs, output_type):
+    (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = (
+        generate_nvfp4_quantized_tensors(configs[:-1], output_type)
+    )
+    a_gs = block_scale_configs[0].global_scale
+    b_gs = block_scale_configs[1].global_scale
+
+    scaled_dot_general = partial(
+        scaled_dot_general_wrapper,
+        configs=block_scale_configs
+    )
+
+    dimension_numbers = configs[2]
+    is_training = configs[-1]
+    def fwd(a, b, is_ref=False, use_normalized=False):
+      fn = jax.lax.dot_general if is_ref else scaled_dot_general
+      if is_ref and use_normalized:
+        dms = (([2], [2]), ([0], [0]))
+      else:
+        dms = dimension_numbers
+
+      y = fn(a, b, dms,
+             preferred_element_type=output_type)
+
+      return jnp.sum(y) if is_training else y
+
+    if is_training:
+      j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
+      j_train_ref = jax.jit(
+          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
+      )
+      j_train_fwd_ref = jax.jit(
+          jax.value_and_grad(
+              partial(fwd, is_ref=True, use_normalized=True), argnums=[0, 1]
+          )
+      )
+      out, (x_grad, w_grad) = j_train(a_raw, b_raw)
+      _, (x_grad_ref, w_grad_ref) = j_train_ref(a_raw, b_raw)
+      out_ref, _ = j_train_fwd_ref(a_dq, b_dq)
+
+      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
+      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
+      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
+    else:
+      j_inference = jax.jit(fwd)
+      j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True))
+      out = j_inference(a_raw, b_raw)
+      out_ref = jnp.reshape(j_inference_ref(a_dq, b_dq), out.shape)
+      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=2e-1)
+
   @jtu.sample_product(
       configs=[
           # a_shape, b_shape, dimension_numbers, is_training