# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
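
"""Tests for cuDNN block-scaled matmul and scaled dot_general.

Exercises scaled_matmul_wrapper and scaled_dot_general_wrapper with MXFP8 and
NVFP4 block-scale configurations, including sharded execution on a
("dp", "tp") mesh and the collectives expected in the compiled HLO.
"""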

from functools import partial

from absl.testing import absltest

import re
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax._src import config
from jax._src import dtypes as _dtypes
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.scaled_matmul_stablehlo import (
    scaled_matmul_wrapper,
    scaled_dot_general_wrapper,
    shape_normalization,
    quantize,
    BlockScaleConfig,
)

config.parse_flags_with_absl()
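
# Each entry in sharding_configs below maps an (lhs, rhs) input sharding to
# the collectives expected in the compiled HLO and to the expected output
# PartitionSpec. For example, sharding the contracting dimension of both
# operands over "tp" (the first entry) is expected to lower to the
# block-scaled dot custom call followed by an all-reduce of the partial
# results over the "tp" axis.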
input_shardings = [
    (("dp", None, "tp"), ("dp", None, "tp")),
    (("dp", None, "tp"), ("dp", None, None)),
    (("dp", None, "tp"), ("dp", "tp", None)),
    (("dp", None, None), ("dp", "tp", None)),
    (("dp", "tp", None), ("dp", "tp", None)),
    ((None, "dp", "tp"), (None, "dp", "tp")),
    ((None, "tp", None), (None, "tp", None)),
    ((None, None, "tp"), (None, "tp", None)),
]
c_name = "__cudnn$blockScaledDot"
expected_hlos = [
    (c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"),
    ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name),
    ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name),
    (c_name,),
    ("all-gather", "f8e4m3fn[1,256,1024]", "replica_groups=[2,2]<=[4]", c_name),
    (c_name, "reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}"),
    ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name),
    ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name),
]
expected_output_spec = [
    PartitionSpec('dp',),
    PartitionSpec('dp',),
    PartitionSpec('dp', None, 'tp'),
    PartitionSpec('dp', None, 'tp'),
    PartitionSpec('dp', 'tp', None),
    PartitionSpec(None, 'dp', 'tp'),
    PartitionSpec(None, 'tp', None),
    PartitionSpec(None, None, 'tp'),
]
sharding_configs = {
    input_sharding: (hlo, output_spec)
    for input_sharding, hlo, output_spec in zip(input_shardings,
                                                expected_hlos,
                                                expected_output_spec)
}


def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
  assert q_dtype in (jnp.float8_e4m3fn, jnp.float4_e2m1fn)
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(
      jnp.asarray(scale, dtype=compute_dtype), x.shape
  )
  clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
  return clipped_x.astype(q_dtype)


def quantize_dequantize(x, q_dtype, scale, compute_dtype):
  qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
  out = qx.astype(x.dtype) * jnp.broadcast_to(
      jnp.asarray(scale, dtype=x.dtype), qx.shape
  )
  return out
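
# Illustrative sketch only (not executed by the tests): quantize_dequantize
# snaps values onto the grid representable in the target dtype, which is why
# the tests use it (as "cast_to_representable") to build inputs whose
# quantization is exact. Assuming a float32 input:
#
#   x = jnp.array([0.3, -0.7, 1.2], dtype=jnp.float32)
#   x_r = quantize_dequantize(
#       x, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32)
#   # x_r holds the float8_e4m3fn-representable values closest to x.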


def generate_quantized_tensors(
    batch, lhs_non_contract, contract, rhs_non_contract,
    configs, dtype=jnp.float32,
):
  cast_to_representable = partial(
      quantize_dequantize,
      scale=jnp.ones((1,)),
      compute_dtype=dtype,
  )

  k1, k2 = jax.random.split(jax.random.key(123), 2)

  a = cast_to_representable(
      jax.random.uniform(
          k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      configs[0].data_type,
  )
  b = cast_to_representable(
      jax.random.uniform(
          k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      configs[1].data_type,
  )

  dn = ((2,), (0,))
  a_3d = shape_normalization(a, dn)
  b_3d = shape_normalization(b, dn)
  a_q, a_scales = quantize(a, configs[0])
  b_q, b_scales = quantize(b, configs[1])

  return a, b, a_q, b_q, a_scales, b_scales


def shard_and_device_put(
    mesh, a_sharding, b_sharding, a, b, a_scales=None, b_scales=None
):
  a_spec = PartitionSpec(*a_sharding)
  b_spec = PartitionSpec(*b_sharding)

  a_named_sharding = NamedSharding(mesh, a_spec)
  b_named_sharding = NamedSharding(mesh, b_spec)

  a = jax.device_put(a, a_named_sharding)
  b = jax.device_put(b, b_named_sharding)
  if a_scales is not None:
    a_scales = jax.device_put(a_scales, a_named_sharding)
  if b_scales is not None:
    b_scales = jax.device_put(b_scales, b_named_sharding)

  in_shardings = (
      a_named_sharding,
      b_named_sharding,
  )
  if a_scales is not None and b_scales is not None:
    in_shardings = (
        a_named_sharding,
        b_named_sharding,
        a_named_sharding,
        b_named_sharding,
    )
    return a, b, a_scales, b_scales, in_shardings

  return a, b, in_shardings
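
# Note on the helper above: it is used with two calling conventions in this
# file. When block scales are passed it returns
#   (a, b, a_scales, b_scales, in_shardings)
# with a 4-tuple of NamedShardings matching scaled_matmul_wrapper's argument
# order; without scales it returns (a, b, in_shardings) with a 2-tuple, as
# used by the dot_general sharding test below.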


def create_nvfp4_configs(global_scale=None):
  if _dtypes.float4_e2m1fn is None:
    return None
  g_one_scale = jnp.ones((1, ), dtype=jnp.float32)
  nvfp4_config = BlockScaleConfig(
      mode='nvfp4',
      block_size=16,
      data_type=jnp.float4_e2m1fn,
      scale_type=jnp.float8_e4m3fn,
      global_scale=g_one_scale if global_scale is None else global_scale,
      infer_only=False
  )

  return [nvfp4_config for _ in range(3)]


def update_global_scale(config, new_global_scale):
  config.global_scale = new_global_scale
  return config
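
# NVFP4 uses two levels of scaling: one float8_e4m3fn scale per block of 16
# elements plus a single float32 global scale per tensor. Together with
# create_nvfp4_configs, the helper above is used to set the global scale to
# amax / (data_max * scale_max), where data_max is the largest float4_e2m1fn
# value (6.0) and scale_max the largest float8_e4m3fn value (448.0), keeping
# the per-block scales within the representable range of the scale dtype.
# A dequantized element is then approximately
#   x ~= q * block_scale * global_scale
# which is what reshape_and_scale() in generate_nvfp4_quantized_tensors
# reconstructs.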


def generate_nvfp4_quantized_tensors(dot_config, output_type):
  k1, k2 = jax.random.split(jax.random.key(0), 2)

  a_shape, b_shape, dimension_numbers = dot_config
  (a_contract, b_contract), (a_batch, b_batch) = dimension_numbers
  a_dn = (a_contract, a_batch)
  b_dn = (b_contract, b_batch)

  a_raw = jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type)
  b_raw = jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type)
  a = shape_normalization(a_raw, a_dn)
  b = shape_normalization(b_raw, b_dn)

  # Initialize NVFP4 configurations
  block_scale_configs_nvfp4 = create_nvfp4_configs()

  # Compute maximum absolute values for scaling
  amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32)
  amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32)

  # Update global scales
  data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype(
      jnp.float32
  )
  scale_max = jnp.finfo(block_scale_configs_nvfp4[0].scale_type).max.astype(
      jnp.float32
  )

  block_scale_configs_nvfp4[0] = update_global_scale(
      block_scale_configs_nvfp4[0], amax_a / (data_max * scale_max))
  block_scale_configs_nvfp4[1] = update_global_scale(
      block_scale_configs_nvfp4[1], amax_b / (data_max * scale_max))

  # Quantize tensors
  a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
  b_nvfp4, b_scale = quantize(b, block_scale_configs_nvfp4[1])

  # Reshape and scale quantized tensors
  def reshape_and_scale(x, scale, global_scale, bs, k):
    reshaped = x.astype(output_type).reshape(*bs, k // 16, 16)
    scaled = reshaped * jnp.expand_dims(scale.astype(output_type), -1)
    return scaled.reshape(*bs, k) * global_scale.astype(output_type)

  *bs_a, k_a = a_nvfp4.shape
  *bs_b, k_b = b_nvfp4.shape
  assert k_a == k_b

  a_dequantized = reshape_and_scale(
      a_nvfp4, a_scale, block_scale_configs_nvfp4[0].global_scale, bs_a, k_a)
  b_dequantized = reshape_and_scale(
      b_nvfp4, b_scale, block_scale_configs_nvfp4[1].global_scale, bs_b, k_b)

  return (
      (a_raw, b_raw),
      (a_dequantized, b_dequantized),
      (a_nvfp4, b_nvfp4, a_scale, b_scale),
      block_scale_configs_nvfp4
  )


def create_mxfp8_configs():
  if _dtypes.float8_e8m0fnu is None:
    return None

  mxfp8_config = BlockScaleConfig(
      mode='mxfp8',
      block_size=32,
      data_type=jnp.float8_e4m3fn,
      scale_type=jnp.float8_e8m0fnu,
      global_scale=None,
      infer_only=False
  )

  return [mxfp8_config for _ in range(3)]
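
# For contrast with the NVFP4 config above: MXFP8 quantizes data to
# float8_e4m3fn with one float8_e8m0fnu (power-of-two) scale per block of
# 32 elements and no per-tensor global scale, so the tests can call
# quantize() directly without an amax-based calibration step.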


def get_hlo_text(in_shardings, block_scale_configs):
  mesh_names = ("dp", "tp")
  devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
  mesh = Mesh(devices, mesh_names)
  _, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
      2, 512, 1024, 512, block_scale_configs,
  )

  with mesh:
    a_q, b_q, a_scales, b_scales, in_shardings = shard_and_device_put(
        mesh, in_shardings[0], in_shardings[1], a_q, b_q, a_scales, b_scales
    )
    pjit_fn = jax.jit(scaled_matmul_wrapper, in_shardings=in_shardings)
    hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
  return hlo.as_text()


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class ScaledMatmulTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    try:
      cudnn_version = check_cudnn_version()
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if _dtypes.float8_e8m0fnu is None:
      self.skipTest("Requires ml_dtypes >= 0.5.0 to support float8_e8m0fnu")
    if _dtypes.float4_e2m1fn is None:
      self.skipTest("Requires ml_dtypes >= 0.5.0 to support float4_e2m1fn")
    if cudnn_version < 90700:
      self.skipTest("Requires cuDNN >= 9.7.0")
    if not jtu.is_cuda_compute_capability_at_least("10.0"):
      self.skipTest("Requires at least Blackwell arch")

  mxfp8_configs = create_mxfp8_configs()

  @jtu.sample_product(
      in_shardings=sharding_configs,
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_collectives(self, in_shardings, block_scale_configs):
    if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
      self.skipTest("Partition test requires at least 4 GPUs")

    expected_hlo = sharding_configs[in_shardings][0]
    hlo_text = get_hlo_text(in_shardings, block_scale_configs)

    hlo_pattern = re.compile(
        r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
    )
    self.assertRegex(
        hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}"
    )

  @jtu.sample_product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul_nvfp4(
      self, contract, lhs_non_contract, dtype,
  ):
    batch, rhs_non_contract = 2, 128
    dot_config = (
        (batch, lhs_non_contract, contract),
        (batch, rhs_non_contract, contract),
        (([2], [2]), ([0], [0]))
    )
    _, (a_dq, b_dq), (a_q, b_q, a_s, b_s), block_scale_configs = (
        generate_nvfp4_quantized_tensors(dot_config, dtype)
    )
    a_gs = block_scale_configs[0].global_scale
    b_gs = block_scale_configs[1].global_scale

    def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
      out = scaled_matmul_wrapper(
          lhs,
          rhs,
          lhs_scales,
          rhs_scales,
          preferred_element_type=jnp.float32,
      )
      gs = a_gs * b_gs
      return (out * gs).astype(out_type)

    j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
    hlo_text = (
        j_scaled_matmul.lower(a_q, b_q, a_s, b_s)
        .compile()
        .as_text()
    )
    hlo_pattern = re.compile(
        r".*".join([re.escape(x) for x in ("custom-call", c_name)])
    )
    self.assertRegex(hlo_text, hlo_pattern)

    out = j_scaled_matmul(a_q, b_q, a_s, b_s)
    out_ref = jnp.einsum(
        "BMK,BNK->BMN", a_dq, b_dq
    )
    self.assertArraysAllClose(
        out, out_ref.astype(dtype), rtol=1e-2, atol=5e-2
    )

  @jtu.sample_product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul(
      self, contract, lhs_non_contract, dtype, block_scale_configs,
  ):
    batch, rhs_non_contract = 2, 128
    a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
        batch, lhs_non_contract, contract, rhs_non_contract,
        block_scale_configs, dtype=dtype,
    )

    def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
      return scaled_matmul_wrapper(
          lhs,
          rhs,
          lhs_scales,
          rhs_scales,
          preferred_element_type=out_type,
      )

    j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
    hlo_text = (
        j_scaled_matmul.lower(a_q, b_q, a_scales, b_scales)
        .compile()
        .as_text()
    )
    hlo_pattern = re.compile(
        r".*".join([re.escape(x) for x in ("custom-call", c_name)])
    )
    self.assertRegex(hlo_text, hlo_pattern)

    out = j_scaled_matmul(a_q, b_q, a_scales, b_scales)
    out_ref = np.einsum(
        "BMK,BNK->BMN", a.astype(jnp.float32), b.astype(jnp.float32)
    )
    self.assertArraysAllClose(
        out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
    )

  @jtu.sample_product(
      in_shardings=sharding_configs,
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
    if len(jax.local_devices()) < 4:
      self.skipTest("Requires at least 4 devices to run sharding tests.")
    batch, contract, non_contract = 2, 1024, 256
    a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
        batch, non_contract, contract, non_contract, block_scale_configs,
    )

    devices = np.array(jax.local_devices()[:4])
    devices = devices.reshape((2, 2))
    expected_output_spec = sharding_configs[in_shardings][1]

    with Mesh(devices, ("dp", "tp")) as mesh:
      a_q, b_q, a_scales, b_scales, input_shardings = (
          shard_and_device_put(
              mesh,
              in_shardings[0],
              in_shardings[1],
              a_q,
              b_q,
              a_scales,
              b_scales,
          )
      )

      args = [a_q, b_q, a_scales, b_scales]
      j_scaled_matmul = jax.jit(
          scaled_matmul_wrapper, in_shardings=input_shardings
      )
      hlo_compiled = j_scaled_matmul.lower(*args).compile()
      hlo_pattern = re.compile(
          r".*".join([re.escape(x) for x in ("custom-call", c_name)])
      )
      self.assertRegex(hlo_compiled.as_text(), hlo_pattern)

      j_ref = jax.jit(
          partial(
              jax.lax.dot_general,
              dimension_numbers=(([2], [2]), ([0], [0])),
          ),
          in_shardings=input_shardings[:2],
      )

      out = j_scaled_matmul(*args)
      out_ref = j_ref(a, b)
      expected_output_sharding = NamedSharding(
          mesh=mesh, spec=expected_output_spec
      )
      self.assertArraysAllClose(out, out_ref, rtol=1e-3, atol=1e-3)
      self.assertTrue(
          out.sharding.is_equivalent_to(expected_output_sharding, out.ndim)
      )


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class ScaledDotGeneralTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    try:
      cudnn_version = check_cudnn_version()
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if _dtypes.float8_e8m0fnu is None:
      self.skipTest("Requires ml_dtypes >= 0.5.0 to support float8_e8m0fnu")
    if cudnn_version < 90700:
      self.skipTest("Requires cuDNN >= 9.7.0")
    if not jtu.is_cuda_compute_capability_at_least("10.0"):
      self.skipTest("Requires at least Blackwell arch")

  block_scale_configs = create_mxfp8_configs()

  @jtu.sample_product(
      shape=[
          (1, 128, 128),
          (64, 32),
          (1024, 2048),
      ],
  )
  @jtu.run_on_devices("cuda")
  def test_quantize_nvfp4(self, shape):
    # Check that the quantization logic gives the same results under jit
    # (XLA) as in eager mode.
    output_type = jnp.float32
    k1, k2 = jax.random.split(jax.random.key(0), 2)

    a = jax.random.uniform(k1, shape, minval=-1.0, dtype=output_type)

    block_scale_configs_nvfp4 = create_nvfp4_configs()
    data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(jnp.float32)
    scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
    amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) / (data_max * scale_max)
    block_scale_configs_nvfp4[0] = update_global_scale(
        block_scale_configs_nvfp4[0], jnp.asarray(amax_a, jnp.float32)
    )

    def fn(a):
      a_nvfp4, a_scale = quantize(a, block_scale_configs_nvfp4[0])
      return a_nvfp4, a_scale

    out_q, scale = jax.jit(fn)(a)
    out_q_ref, scale_ref = fn(a)
    self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5)
    self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5)

  @jtu.sample_product(
      configs=[
          # a_shape, b_shape, dimension_numbers, is_training
          ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0])), False),
          ((30, 64), (100, 64), (([1], [1]), ([], [])), False),
          ((192, 96), (160, 96), (([1], [1]), ([], [])), True),
          ((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True),
          ((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True),
          (
              (1, 128, 128, 2),
              (128, 1, 2, 128),
              (([2], [0]), ([0, 3], [1, 2])),
              True,
          ),
      ],
      output_type=[jnp.float32, jnp.float16, jnp.bfloat16],
  )
  @jtu.run_on_devices("cuda")
  def test_dot_general_nvfp4(self, configs, output_type):
    (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = (
        generate_nvfp4_quantized_tensors(configs[:-1], output_type)
    )
    a_gs = block_scale_configs[0].global_scale
    b_gs = block_scale_configs[1].global_scale

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=block_scale_configs
    )

    dimension_numbers = configs[2]
    is_training = configs[-1]
    def fwd(a, b, is_ref=False, use_normalized=False):
      fn = jax.lax.dot_general if is_ref else scaled_dot_general
      if is_ref and use_normalized:
        dms = (([2], [2]), ([0], [0]))
      else:
        dms = dimension_numbers

      y = fn(a, b, dms,
             preferred_element_type=output_type)

      return jnp.sum(y) if is_training else y

    if is_training:
      j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
      j_train_ref = jax.jit(
          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
      )
      j_train_fwd_ref = jax.jit(
          jax.value_and_grad(
              partial(fwd, is_ref=True, use_normalized=True), argnums=[0, 1]
          )
      )
      out, (x_grad, w_grad) = j_train(a_raw, b_raw)
      _, (x_grad_ref, w_grad_ref) = j_train_ref(a_raw, b_raw)
      out_ref, _ = j_train_fwd_ref(a_dq, b_dq)

      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
    else:
      j_inference = jax.jit(fwd)
      j_inference_ref = jax.jit(partial(fwd, is_ref=True, use_normalized=True))
      out = j_inference(a_raw, b_raw)
      out_ref = jnp.reshape(j_inference_ref(a_dq, b_dq), out.shape)
      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=2e-1)

  @jtu.sample_product(
      configs=[
          # a_shape, b_shape, dimension_numbers, is_training
          ((1, 32), (2, 32), (([1], [1]), ([], [])), False),
          ((30, 64), (100, 64), (([1], [1]), ([], [])), False),
          ((192, 96), (160, 96), (([1], [1]), ([], [])), True),
          ((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True),
          ((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True),
          (
              (1, 128, 128, 2),
              (128, 1, 2, 128),
              (([2], [0]), ([0, 3], [1, 2])),
              True,
          ),
      ],
      output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
  )
  @jtu.run_on_devices("cuda")
  def test_dot_general(self, configs, output_type):
    cast_to_representable = partial(
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )
    k1, k2 = jax.random.split(jax.random.key(0), 2)

    a_shape, b_shape, dimension_numbers, is_training = configs
    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
        self.block_scale_configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
        self.block_scale_configs[1].data_type,
    )

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=self.block_scale_configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.lax.dot_general if is_ref else scaled_dot_general
      y = fn(a, b, dimension_numbers,
             preferred_element_type=output_type)
      return jnp.sum(y)

    if is_training:
      j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))

      j_train_ref = jax.jit(
          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
      )
      out, (x_grad, w_grad) = j_train(a, b)
      out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)

      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
    else:
      j_inference = jax.jit(fwd)
      j_inference_ref = jax.jit(partial(fwd, is_ref=True))
      out = j_inference(a, b)
      out_ref = j_inference_ref(a, b)
      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)

  @jtu.sample_product(in_shardings=sharding_configs)
  @jtu.run_on_devices("cuda")
  def test_dot_general_sharded(self, in_shardings):
    if len(jax.local_devices()) < 4:
      self.skipTest("Requires at least 4 devices to run sharding tests.")

    cast_to_representable = partial(
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )

    dimension_numbers = (([2], [2]), ([0], [0]))
    a_shape = (2, 128, 512)
    b_shape = (2, 256, 512)

    k1, k2 = jax.random.split(jax.random.key(0), 2)
    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0),
        self.block_scale_configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0),
        self.block_scale_configs[1].data_type,
    )

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=self.block_scale_configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.lax.dot_general if is_ref else scaled_dot_general
      y = fn(a, b, dimension_numbers)
      # Use a slightly more involved loss function to avoid constant
      # gradients, whose sharding info might be optimized away and then
      # cause issues with the custom scaled_matmul op.
      return jnp.sum(jnp.tanh(y))

    devices = np.array(jax.local_devices()[:4])
    devices = devices.reshape((2, 2))
    with Mesh(devices, ("dp", "tp")) as mesh:
      a, b, input_shardings = (
          shard_and_device_put(
              mesh,
              in_shardings[0],
              in_shardings[1],
              a,
              b,
          )
      )

      j_train = jax.jit(jax.value_and_grad(partial(fwd), argnums=[0, 1]),
                        in_shardings=input_shardings)
      hlo_text = j_train.lower(a, b).compile().as_text()
      hlo_pattern = re.compile(
          r".*".join([re.escape(x) for x in ("custom-call", c_name)])
      )
      self.assertRegex(hlo_text, hlo_pattern)

      j_train_ref = jax.jit(
          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]),
          in_shardings=input_shardings
      )
      out, (x_grad, w_grad) = j_train(a, b)
      out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)

  @jtu.sample_product(
      configs=[
          ((1, 128, 256), (1, 128, 256), (0, 0, 0)),
          ((2, 128, 128), (2, 128, 128), (0, 0, 0)),
          ((2, 128, 128), (128, 2, 128), (0, 1, 2)),
      ]
  )
  @jtu.run_on_devices("cuda")
  def test_dot_general_vmap(self, configs):
    cast_to_representable = partial(
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )
    k1, k2 = jax.random.split(jax.random.key(0), 2)

    a_shape, b_shape, vmap_axes = configs
    a_axis, b_axis, o_axis = vmap_axes
    dimension_numbers = (([1], [1]), ([], []))

    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0),
        self.block_scale_configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0),
        self.block_scale_configs[1].data_type,
    )

    scaled_dot_general = partial(
        scaled_dot_general_wrapper,
        configs=self.block_scale_configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.vmap(
          jax.lax.dot_general if is_ref else scaled_dot_general,
          in_axes=(a_axis, b_axis, None),
          out_axes=o_axis,
      )
      y = fn(a, b, dimension_numbers)
      return jnp.sum(y)

    j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
    j_train_ref = jax.jit(
        jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
    )
    out, (x_grad, w_grad) = j_train(a, b)
    out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)

    self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e2)
    self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
    self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())