Merge pull request #26944 from wenscarl:wenscarl/nvfp4

PiperOrigin-RevId: 736620378

Commit: 726f49cbca

This change adds an "nvfp4" block-scaling mode to the cuDNN scaled-matmul path and its tests: float4_e2m1fn data with float8_e4m3fn block scales (block size 16) plus a float32 per-tensor global scale.
@@ -367,6 +367,7 @@ def scaled_matmul_wrapper(
  preferred_element_type = dtypes.canonicalize_dtype(
      np.dtype(preferred_element_type)
  )

  out = _scaled_matmul(
      lhs,
      rhs,
@@ -374,6 +375,7 @@ def scaled_matmul_wrapper(
      rhs_scales,
      preferred_element_type=preferred_element_type,
  )

  return out


def shape_normalization(x, dimension_numbers):
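For orientation, a minimal sketch (not part of the diff) of the math scaled_matmul_wrapper is expected to compute, based on how the new tests below call it: operands are 3-D [B, M, K] / [B, N, K] arrays with one scale per 16-element block of the contracting axis, and the contraction matches einsum("BMK,BNK->BMN", ...). The function name and shapes here are illustrative assumptions; the real wrapper dispatches to a cuDNN block-scaled matmul custom call.

```python
# Reference semantics only; not the implementation used by the wrapper.
import jax.numpy as jnp

def reference_scaled_matmul(lhs_q, rhs_q, lhs_scales, rhs_scales):
  # lhs_q: [B, M, K], rhs_q: [B, N, K] quantized values.
  # lhs_scales: [B, M, K // 16], rhs_scales: [B, N, K // 16] block scales.
  B, M, K = lhs_q.shape
  lhs = lhs_q.astype(jnp.float32).reshape(B, M, K // 16, 16)
  lhs = (lhs * lhs_scales.astype(jnp.float32)[..., None]).reshape(B, M, K)
  _, N, _ = rhs_q.shape
  rhs = rhs_q.astype(jnp.float32).reshape(B, N, K // 16, 16)
  rhs = (rhs * rhs_scales.astype(jnp.float32)[..., None]).reshape(B, N, K)
  # Same contraction the nvfp4 test below checks against.
  return jnp.einsum("BMK,BNK->BMN", lhs, rhs)
```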
@@ -489,12 +491,14 @@ def quantize(x, config):
    scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
  elif config.mode == "nvfp4":
    assert config.scale_type == jnp.float8_e4m3fn
    # shuw(TODO): Add when XLA is ready and e2m1fn is available.
    scales_q = scales
    scales_x = x
    assert config.global_scale.dtype == jnp.float32

    scales = scales / config.global_scale
    scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn))
    scaled_x = x / (scales_q.astype(jnp.float32) *
                    config.global_scale).astype(x.dtype)
  else:
    raise ValueError(f"Unrecognized mode: {config.mode}.")

  clipped_x = jnp.clip(scaled_x, -MAX, MAX)
  x_q = clipped_x.astype(config.data_type)

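A standalone sketch (not part of the diff) of the two-level scaling the nvfp4 branch above implements: per-block scales are divided by a single float32 global scale, stored in float8_e4m3fn, and the data is divided by the reconstructed product before the cast to float4_e2m1fn. How the per-block scale is derived is not shown in this hunk; block_amax / data_max is assumed here, and the function and variable names are illustrative. (The real code additionally wraps the e4m3 cast in jax.lax.optimization_barrier, as shown above.)

```python
# Assumes ml_dtypes >= 0.5.0 so jnp.float4_e2m1fn / jnp.float8_e4m3fn exist.
import jax.numpy as jnp

def nvfp4_quantize_block(x, global_scale, block_size=16):
  # x: [..., K] high-precision input, K divisible by block_size.
  xr = x.reshape(*x.shape[:-1], -1, block_size)
  amax = jnp.max(jnp.abs(xr), axis=-1, keepdims=True).astype(jnp.float32)
  # Assumed per-block scale: block amax over the e2m1 representable max (6.0).
  scales = amax / jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)
  # Store block scales relative to the global scale, quantized to e4m3.
  scales_q = (scales / global_scale).astype(jnp.float8_e4m3fn)
  # Divide by the *quantized* scale times the global scale so that
  # dequantization (x_q * scales_q * global_scale) round-trips.
  scaled = xr / (scales_q.astype(jnp.float32) * global_scale)
  x_q = jnp.clip(scaled, -6.0, 6.0).astype(jnp.float4_e2m1fn)
  return x_q.reshape(x.shape), scales_q[..., 0]

x_q, s_q = nvfp4_quantize_block(
    jnp.linspace(-3.0, 3.0, 32).reshape(2, 16), jnp.float32(0.01))
```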
@@ -504,10 +508,6 @@ def quantize(x, config):
  )
  return x_q, scales_q


def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
                    configs):
  if preferred_element_type is None:
@@ -529,10 +529,18 @@ def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
  lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
  rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)

  out_dtype = preferred_element_type
  if configs[0].mode == 'nvfp4':
    out_dtype = jnp.float32

  out = scaled_matmul_wrapper(
      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type
      lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type=out_dtype
  )

  if configs[0].mode == 'nvfp4':
    out *= (configs[0].global_scale * configs[1].global_scale)
    out = out.astype(preferred_element_type)

  expanded_out_shape = compute_dot_output_shape(
      lhs.shape, rhs.shape, lhs_dn, rhs_dn
  )
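Why the nvfp4 path forces a float32 matmul output and then rescales (a hedged note, not part of the diff): each quantized operand carries its per-tensor global scale divided out, so the block-scaled matmul reconstructs only lhs @ rhs.T / (g_lhs * g_rhs); multiplying by g_lhs * g_rhs in float32 before the final cast restores the original magnitude. A tiny check of the algebra, with illustrative names and block scales omitted:

```python
import jax.numpy as jnp

g_lhs, g_rhs = jnp.float32(0.02), jnp.float32(0.05)  # illustrative global scales
a = jnp.array([[1.0, -2.0]], jnp.float32)
b = jnp.array([[0.5, 4.0]], jnp.float32)
a_q, b_q = a / g_lhs, b / g_rhs      # stand-ins for globally-scaled operands
out = a_q @ b_q.T                    # what the block-scaled matmul produces
out = out * (g_lhs * g_rhs)          # the epilogue added in this hunk
assert jnp.allclose(out, a @ b.T)
```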
@@ -564,13 +572,15 @@ def scaled_dot_general_transpose_lhs(
  g_3d = shape_normalization(g, g_dn)

  g_config, y_config = configs[0], configs[1]
  if configs[0].mode != 'nvfp4':
    g_q, g_scales = quantize(g_3d, g_config)
    y_q, y_scales = quantize(y_3d, y_config)

  g_q, g_scales = quantize(g_3d, g_config)
  y_q, y_scales = quantize(y_3d, y_config)

    out = scaled_matmul_wrapper(
        g_q, y_q, g_scales, y_scales, preferred_element_type
    )
  out = scaled_matmul_wrapper(
      g_q, y_q, g_scales, y_scales, preferred_element_type
  )
  else:
    out = jnp.matmul(g_3d, jnp.permute_dims(y_3d, (0, 2, 1)), preferred_element_type=preferred_element_type)

  expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn)
  expanded_out = jnp.reshape(out, expanded_out_shape)

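Note on the else branch above (not part of the diff): for nvfp4 the transpose (backward) rule skips quantization and computes the gradient contraction directly in the preferred type. The jnp.matmul over a permuted y_3d is the same batched contraction as the quantized path; a small equivalence check with assumed shapes:

```python
import jax
import jax.numpy as jnp

# Illustrative shapes: g_3d is [B, M, K], y_3d is [B, N, K].
g_3d = jax.random.normal(jax.random.key(0), (2, 4, 8), dtype=jnp.float32)
y_3d = jax.random.normal(jax.random.key(1), (2, 3, 8), dtype=jnp.float32)

out_matmul = jnp.matmul(g_3d, jnp.permute_dims(y_3d, (0, 2, 1)))
out_einsum = jnp.einsum("BMK,BNK->BMN", g_3d, y_3d)
assert jnp.allclose(out_matmul, out_einsum, atol=1e-5)
```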
@@ -67,13 +67,15 @@ expected_output_spec = [
]
sharding_configs = {
    input_sharding: (hlo, output_spec)
    for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
    for input_sharding, hlo, output_spec in zip(input_shardings,
                                                expected_hlos,
                                                expected_output_spec)
}

def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations."
  assert q_dtype in (jnp.float8_e4m3fn, )
  assert q_dtype in (jnp.float8_e4m3fn, jnp.float4_e2m1fn)
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(
      jnp.asarray(scale, dtype=compute_dtype), x.shape
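The widened assert lets the reference quantizer handle the new float4_e2m1fn type alongside float8_e4m3fn. The dtype_max it derives differs sharply between the two formats, which is why nvfp4 needs the extra per-tensor global scale; a quick illustration (not part of the diff, requires ml_dtypes >= 0.5.0):

```python
import jax.numpy as jnp

# Representable maxima of the two quantized element types used in this test.
print(jnp.finfo(jnp.float8_e4m3fn).max)   # 448.0
print(jnp.finfo(jnp.float4_e2m1fn).max)   # 6.0
# e2m1's tiny range is compensated by per-16-element block scales plus a
# float32 global scale (see create_nvfp4_configs below).
```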
@@ -153,6 +155,84 @@ def shard_and_device_put(

  return a, b, in_shardings


def create_nvfp4_configs(global_scale=None):
  if _dtypes.float4_e2m1fn is None:
    return None
  g_one_scale = jnp.ones((1, ), dtype=jnp.float32)
  nvfp4_config = BlockScaleConfig(
      mode='nvfp4',
      block_size=16,
      data_type=jnp.float4_e2m1fn,
      scale_type=jnp.float8_e4m3fn,
      global_scale=g_one_scale if global_scale is None else global_scale,
      infer_only=False
  )

  return [nvfp4_config for _ in range(3)]


def update_global_scale(config, new_global_scale):
  config.global_scale = new_global_scale
  return config


def generate_nvfp4_quantized_tensors(dot_config, output_type):
  k1, k2 = jax.random.split(jax.random.key(0), 2)

  a_shape, b_shape, dimension_numbers = dot_config
  (a_contract, b_contract), (a_batch, b_batch) = dimension_numbers
  a_dn = (a_contract, a_batch)
  b_dn = (b_contract, b_batch)

  a_raw = jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type)
  b_raw = jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type)
  a = shape_normalization(a_raw, a_dn)
  b = shape_normalization(b_raw, b_dn)

  # Initialize NVFP4 configurations
  block_scale_configs_nvfp4 = create_nvfp4_configs()

  # Compute maximum absolute values for scaling
  amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32)
  amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32)

  # Update global scales
  data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype(
      jnp.float32
  )
  scale_max = jnp.finfo(block_scale_configs_nvfp4[0].scale_type).max.astype(
      jnp.float32
  )

  block_scale_configs_nvfp4[0] = update_global_scale(
      block_scale_configs_nvfp4[0], amax_a / (data_max * scale_max))
  block_scale_configs_nvfp4[1] = update_global_scale(
      block_scale_configs_nvfp4[1], amax_b / (data_max * scale_max))

  # Quantize tensors
  a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
  b_nvfp4, b_scale = quantize(b, block_scale_configs_nvfp4[1])

  # Reshape and scale quantized tensors
  def reshape_and_scale(x, scale, global_scale, bs, k):
    reshaped = x.astype(output_type).reshape(*bs, k // 16, 16)
    scaled = reshaped * jnp.expand_dims(scale.astype(output_type), -1)
    return scaled.reshape(*bs, k) * global_scale.astype(output_type)

  *bs_a, k_a = a_nvfp4.shape
  *bs_b, k_b = b_nvfp4.shape
  assert k_a == k_b

  a_dequantized = reshape_and_scale(
      a_nvfp4, a_scale, block_scale_configs_nvfp4[0].global_scale, bs_a, k_a)
  b_dequantized = reshape_and_scale(
      b_nvfp4, b_scale, block_scale_configs_nvfp4[1].global_scale, bs_b, k_b)

  return (
      (a_raw, b_raw),
      (a_dequantized, b_dequantized),
      (a_nvfp4, b_nvfp4, a_scale, b_scale),
      block_scale_configs_nvfp4
  )


def create_mxfp8_configs():
  if _dtypes.float8_e8m0fnu is None:
    return None
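On the global scale chosen above: generate_nvfp4_quantized_tensors sets global_scale = amax / (data_max * scale_max). Assuming block scales of roughly block_amax / data_max (the usual choice, not shown in this hunk), this keeps even the largest block scale at scale_max and therefore representable in float8_e4m3fn. A worked example with an assumed amax (not part of the diff):

```python
import jax.numpy as jnp

amax = jnp.float32(37.5)  # assumed tensor-wide max(|x|)
data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)   # 6.0
scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)  # 448.0
global_scale = amax / (data_max * scale_max)            # ~0.01395
# Worst-case block scale, relative to the global scale:
largest_block_scale = (amax / data_max) / global_scale  # == scale_max == 448.0
print(global_scale, largest_block_scale)
```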
@@ -184,7 +264,6 @@ def get_hlo_text(in_shardings, block_scale_configs):
  hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
  return hlo.as_text()


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class ScaledMatmulTest(jtu.JaxTestCase):

@@ -197,6 +276,8 @@ class ScaledMatmulTest(jtu.JaxTestCase):
      return
    if _dtypes.float8_e8m0fnu is None:
      self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
    if _dtypes.float4_e2m1fn is None:
      self.skipTest("Requires >= ml_dtypes 0.5.0 to support float4_e2m1fn")
    if cudnn_version < 90700:
      self.skipTest("Requires >= cuDNN 9.7.0")
    if not jtu.is_cuda_compute_capability_at_least("10.0"):
@@ -223,6 +304,57 @@ class ScaledMatmulTest(jtu.JaxTestCase):
        hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}"
    )

  @jtu.sample_product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul_nvfp4(
      self, contract, lhs_non_contract, dtype,
  ):
    batch, rhs_non_contract = 2, 128
    dot_config = (
        (batch, lhs_non_contract, contract),
        (batch, rhs_non_contract, contract),
        (([2], [2]), ([0], [0]))
    )
    _, (a_dq, b_dq), (a_q, b_q, a_s, b_s), block_scale_configs = (
        generate_nvfp4_quantized_tensors(dot_config, dtype)
    )
    a_gs = block_scale_configs[0].global_scale
    b_gs = block_scale_configs[1].global_scale

    def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
      out = scaled_matmul_wrapper(
          lhs,
          rhs,
          lhs_scales,
          rhs_scales,
          preferred_element_type=jnp.float32,
      )
      gs = a_gs * b_gs
      return (out * gs).astype(out_type)

    j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
    hlo_text = (
        j_scaled_matmul.lower(a_q, b_q, a_s, b_s)
        .compile()
        .as_text()
    )
    hlo_pattern = re.compile(
        r".*".join([re.escape(x) for x in ("custom-call", c_name)])
    )
    self.assertRegex(hlo_text, hlo_pattern)

    out = j_scaled_matmul(a_q, b_q, a_s, b_s)
    out_ref = jnp.einsum(
        "BMK,BNK->BMN", a_dq, b_dq
    )
    self.assertArraysAllClose(
        out, out_ref.astype(dtype), rtol=1e-2, atol=5e-2
    )

  @jtu.sample_product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
@@ -326,7 +458,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
    )


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
class ScaledDotGeneralTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
@@ -344,6 +476,106 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):

  block_scale_configs = create_mxfp8_configs()

  @jtu.sample_product(
      shape=[
          (1, 128, 128),
          (64, 32),
          (1024, 2048),
      ],
  )
  @jtu.run_on_devices("cuda")
  def test_quantize_nvfp4(self, shape):
    # To test the q-dq logic is valid with XLA
    output_type = jnp.float32
    k1, k2 = jax.random.split(jax.random.key(0), 2)

    a = jax.random.uniform(k1, shape, minval=-1.0, dtype=output_type)

    block_scale_configs_nvfp4 = create_nvfp4_configs()
    data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)
    scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
    amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) / (data_max * scale_max)
    block_scale_configs_nvfp4[0] = update_global_scale(
        block_scale_configs_nvfp4[0], jnp.asarray(amax_a, jnp.float32)
    )

    def fn(a):
      a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
      return a_nvfp4, a_scale

    out_q, scale = jax.jit(fn)(a)
    out_q_ref, scale_ref = fn(a)
    self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5)
    self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5)

  @jtu.sample_product(
      configs=[
          # a_shape, b_shape, dimension_numbers, is_training
          ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0])), False),
          ((30, 64), (100, 64), (([1], [1]), ([], [])), False),
          ((192, 96), (160, 96), (([1], [1]), ([], [])), True),
          ((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True),
          ((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True),
          (
              (1, 128, 128, 2),
              (128, 1, 2, 128),
              (([2], [0]), ([0, 3], [1, 2])),
              True,
          ),
      ],
      output_type=[jnp.float32, jnp.float16, jnp.bfloat16],
  )
  @jtu.run_on_devices("cuda")
  def test_dot_general_nvfp4(self, configs, output_type):
    (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = (
        generate_nvfp4_quantized_tensors(configs[:-1], output_type)
    )
    a_gs = block_scale_configs[0].global_scale
    b_gs = block_scale_configs[1].global_scale

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=block_scale_configs
    )

    dimension_numbers = configs[2]
    is_training = configs[-1]
    def fwd(a, b, is_ref=False, use_normalized=False):
      fn = jax.lax.dot_general if is_ref else scaled_dot_general
      if is_ref and use_normalized:
        dms = (([2], [2]), ([0], [0]))
      else:
        dms = dimension_numbers

      y = fn(a, b, dms,
             preferred_element_type=output_type)

      return jnp.sum(y) if is_training else y

    if is_training:
      j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
      j_train_ref = jax.jit(
          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
      )
      j_train_fwd_ref = jax.jit(
          jax.value_and_grad(
              partial(fwd, is_ref=True, use_normalized=True), argnums=[0, 1]
          )
      )
      out, (x_grad, w_grad) = j_train(a_raw, b_raw)
      _, (x_grad_ref, w_grad_ref) = j_train_ref(a_raw, b_raw)
      out_ref, _ = j_train_fwd_ref(a_dq, b_dq)

      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
    else:
      j_inference = jax.jit(fwd)
      j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True))
      out = j_inference(a_raw, b_raw)
      out_ref = jnp.reshape(j_inference_ref(a_dq, b_dq), out.shape)
      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=2e-1)

  @jtu.sample_product(
      configs=[
          # a_shape, b_shape, dimension_numbers, is_training
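Finally, a condensed sketch (not part of the diff) of the end-to-end usage exercised by test_dot_general_nvfp4 above. It assumes the helpers added in this diff (create_nvfp4_configs, update_global_scale, scaled_dot_general_wrapper) are in scope, plus a CUDA device with compute capability >= 10.0, cuDNN >= 9.7.0, and ml_dtypes >= 0.5.0; shapes and variable names are illustrative.

```python
import jax
import jax.numpy as jnp

configs = create_nvfp4_configs()        # one nvfp4 BlockScaleConfig per operand
a = jax.random.uniform(jax.random.key(0), (1, 128, 128), minval=-1.0)
b = jax.random.uniform(jax.random.key(1), (1, 128, 128), minval=-1.0)

# Per-tensor global scales, computed the same way as in
# generate_nvfp4_quantized_tensors above.
data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)
scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
configs[0] = update_global_scale(
    configs[0], jnp.max(jnp.abs(a)).astype(jnp.float32) / (data_max * scale_max))
configs[1] = update_global_scale(
    configs[1], jnp.max(jnp.abs(b)).astype(jnp.float32) / (data_max * scale_max))

def f(x, y):
  return scaled_dot_general_wrapper(
      x, y, (([2], [2]), ([0], [0])),
      preferred_element_type=jnp.float32, configs=configs)

out = jax.jit(f)(a, b)                  # [1, 128, 128] result in float32
```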