Improve based on comment # 1
This commit is contained in:
parent ae111f7c97
commit bfb9d3ca4b
@@ -32,18 +32,6 @@ class BlockScaleConfig:
   global_scale: Array | None
   infer_only: bool
 
-mxfp8_config = BlockScaleConfig(
-    mode='mxfp8',
-    block_size=32,
-    data_type=jnp.float8_e4m3fn,
-    scale_type=jnp.float8_e8m0fnu,
-    global_scale=None,
-    infer_only=False
-)
-
-BlockScaleConfigs = List[BlockScaleConfig]
-mxfp8_configs: BlockScaleConfigs = [mxfp8_config, mxfp8_config, mxfp8_config]
-
 def default_layouts(*shapes):
   return [range(len(shape) - 1, -1, -1) for shape in shapes]
@@ -490,6 +478,8 @@ def quantize(x, config):
     # shuw(TODO): Add when XLA is ready and e2m1fn is available.
+    scales_q = scales
+    scales_x = x
   else:
     raise ValueError(f"Unrecognized mode: {config.mode}.")
 
   clipped_x = jnp.clip(scaled_x, -MAX, MAX)
   x_q = clipped_x.astype(config.data_type)
@@ -519,38 +509,6 @@ def quantize_dequantize(x, q_dtype, scale, compute_dtype):
   )
   return out
 
-def generate_quantized_tensors(
-    batch, lhs_non_contract, contract, rhs_non_contract,
-    configs, dtype=jnp.float32,
-):
-  cast_to_representable = partial(
-      quantize_dequantize,
-      scale=jnp.ones((1,)),
-      compute_dtype=dtype,
-  )
-
-  k1, k2 = jax.random.split(jax.random.key(123), 2)
-
-  a = cast_to_representable(
-      jax.random.uniform(
-          k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
-      ),
-      configs[0].data_type,
-  )
-  b = cast_to_representable(
-      jax.random.uniform(
-          k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
-      ),
-      configs[1].data_type,
-  )
-
-  dn = ((2,), (0,))
-  a_3d = shape_normalization(a, dn)
-  b_3d = shape_normalization(b, dn)
-  a_q, a_scales = quantize(a, configs[0])
-  b_q, b_scales = quantize(b, configs[1])
-
-  return a, b, a_q, b_q, a_scales, b_scales
 
 
 def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
@@ -625,7 +583,7 @@ def scaled_dot_general_transpose_lhs(
 
 def scaled_dot_general_transpose_rhs(
     g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike,
-    configs: BlockScaleConfigs
+    configs: List[BlockScaleConfig]
 ):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
@@ -709,12 +667,23 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers):
 def scaled_dot_general_wrapper(
     lhs, rhs, dimension_numbers,
     preferred_element_type=jnp.float32,
-    configs: BlockScaleConfigs=mxfp8_configs,
+    configs: List[BlockScaleConfig] | None=None,
 ):
   if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16):
     msg = ('Only support preferred_element_type in (f32, bf16, f16), but got '
            '{preferred_element_type}')
     raise TypeError(msg)
+  if configs is None:
+    mxfp8_config = BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+    configs = [mxfp8_config, mxfp8_config, mxfp8_config]
 
   dimension_numbers = ensure_tuple(dimension_numbers)
   lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim(
       lhs, rhs, dimension_numbers
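For orientation, a minimal caller-side sketch of the reworked wrapper signature. It assumes the internal import path stays `jax._src.cudnn.scaled_matmul_stablehlo` and that a GPU/cuDNN build with block-scaled matmul support is available, so treat it as a sketch rather than a portable snippet:

import jax
import jax.numpy as jnp
from jax._src.cudnn.scaled_matmul_stablehlo import (
    BlockScaleConfig, scaled_dot_general_wrapper)

# Build an explicit per-operand config list; passing configs=None now produces
# this same mxfp8 default inside the wrapper.
cfg = BlockScaleConfig(
    mode='mxfp8',
    block_size=32,
    data_type=jnp.float8_e4m3fn,
    scale_type=jnp.float8_e8m0fnu,
    global_scale=None,
    infer_only=False,
)

lhs = jax.random.normal(jax.random.key(0), (16, 64), dtype=jnp.bfloat16)
rhs = jax.random.normal(jax.random.key(1), (64, 32), dtype=jnp.bfloat16)
dn = (((1,), (0,)), ((), ()))  # contract lhs dim 1 with rhs dim 0, no batch dims

out = scaled_dot_general_wrapper(lhs, rhs, dn, configs=[cfg, cfg, cfg])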
@@ -21,13 +21,17 @@ from functools import partial
 import operator
 import math
 import numpy as np
-from typing import Any, Literal
+from typing import Any, List, Literal
 
 import jax
 import jax.numpy as jnp
 from jax import custom_jvp
 from jax import lax
 from jax._src import config
+from jax._src.cudnn.scaled_matmul_stablehlo import (
+    scaled_matmul_wrapper as cudnn_scaled_matmul,
+    scaled_dot_general_wrapper as cudnn_scaled_dot_general,
+    BlockScaleConfig)
 from jax._src import core
 from jax._src import deprecations
 from jax._src import dtypes
@@ -36,10 +40,6 @@ from jax._src.core import AxisName
 from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
 from jax._src.cudnn.fused_attention_stablehlo import (
     dot_product_attention as cudnn_dot_product_attention, MaskType)
-from jax._src.cudnn.scaled_matmul_stablehlo import (
-    scaled_matmul_wrapper as cudnn_scaled_matmul,
-    scaled_dot_general_wrapper as cudnn_scaled_dot_general,
-    BlockScaleConfigs, mxfp8_configs)
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.numpy import util as numpy_util
@@ -1174,10 +1174,8 @@ def scaled_matmul(
   r"""
   Performs scaled matrix multiplication between two 3D arrays, with scaling
   factors applied to the matrices.
 
   .. math::
     \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs)
 
   Args:
     lhs (Array): A 3D array of shape (B, M, K).
     rhs (Array): A 3D array of shape (B, N, K).
@@ -1185,21 +1183,17 @@ def scaled_matmul(
     rhs_scales (Array): A 3D array of shape (B, N, K_block).
     preferred_element_type (DTypeLike, optional): The preferred data type
       for the computation. Defaults to `jnp.float32`.
 
   Returns:
     Array: A 3D array of shape (B, M, N) representing the scaled matrix
       multiplication result.
 
   Raises:
     AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
       match the number of columns in `rhs` (`rhs_K`).
 
   Notes:
     - The function ensures that the `preferred_element_type` is
       canonicalized before passing it to the underlying computation.
     - Scaling is applied to the matrices based on the `lhs_scales` and
       `rhs_scales` arrays, enabling efficient computations in blocks.
 
   """
   B, M, lhs_K = lhs.shape
   _, N, rhs_K = rhs.shape
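To make the shape contract above concrete, here is a small reference emulation in plain `jax.numpy`. The block semantics (one scale per contiguous 32-wide slice of K) are an assumption based on the mxfp8 configuration used throughout this commit, and `jax.nn.scaled_matmul` is the assumed public name of the function being documented:

import jax
import jax.numpy as jnp

B, M, N, K, block = 2, 8, 4, 64, 32        # K_block = K // block = 2

keys = jax.random.split(jax.random.key(0), 2)
lhs = jax.random.uniform(keys[0], (B, M, K), dtype=jnp.float32)
rhs = jax.random.uniform(keys[1], (B, N, K), dtype=jnp.float32)
lhs_scales = jnp.ones((B, M, K // block), dtype=jnp.float32)
rhs_scales = jnp.ones((B, N, K // block), dtype=jnp.float32)

def reference_scaled_matmul(lhs, rhs, lhs_scales, rhs_scales):
  # Assumed semantics: broadcast each scale over its K-block, then contract K.
  ls = jnp.repeat(lhs_scales, block, axis=-1)
  rs = jnp.repeat(rhs_scales, block, axis=-1)
  return jnp.einsum('bmk,bnk->bmn', lhs * ls, rhs * rs)

out_ref = reference_scaled_matmul(lhs, rhs, lhs_scales, rhs_scales)
assert out_ref.shape == (B, M, N)
# On supported hardware, the cuDNN path would be:
# out = jax.nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales)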
@@ -1222,17 +1216,15 @@ def scaled_dot_general(
     lhs, rhs,
     dimension_numbers,
     preferred_element_type=jnp.float32,
-    configs: BlockScaleConfigs=mxfp8_configs,
+    configs: List[BlockScaleConfig] | None = None,
+    implementation: Literal['cudnn'] | None = None,
 ):
   r"""Scaled dot general operation.
 
   Computes the scaled dot general on lhs, rhs with quantization specified by configs:
 
   .. math::
     \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\
     \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\
     \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
 
   Args:
     lhs: Left-hand side input tensor.
     rhs: Right-hand side input tensor.
@@ -1241,14 +1233,37 @@ def scaled_dot_general(
       `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
     preferred_element_type: The preferred output data type. Supported types are
       `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
-    configs: A `BlockScaleConfigs` objects specifying the scaling
-      configurations for the operation. Defaults to `mxfp8_configs`.
-
+    configs: A list of `BlockScaleConfig` specifying the scaling
+      configurations for the operation. Defaults to `mxfp8`.
+    implementation: A string to control which implementation backend to use.
+      Supported strings are `cudnn` (cuDNN block scaled dot). It defaults
+      to `None`, which will automatically select the best available backend.
   Returns:
     The result of the scaled dot general operation.
   """
-  return cudnn_scaled_dot_general(
-      lhs, rhs, dimension_numbers,
-      preferred_element_type=preferred_element_type,
-      configs=configs
-  )
+  # Create configs if not provided
+  if configs is None:
+    mxfp8_config = BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+    configs = [mxfp8_config for _ in range(3)]
+
+  if implementation is None:
+    implementation = 'cudnn'
+
+  match implementation:
+    case 'cudnn':
+      out = cudnn_scaled_dot_general(
+          lhs, rhs, dimension_numbers,
+          preferred_element_type=preferred_element_type,
+          configs=configs
+      )
+    case _:
+      raise ValueError(f"Unsupported implementation option: {implementation}")
+
+  return out
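A short usage sketch of the updated signature, assuming the function is exposed as `jax.nn.scaled_dot_general`; the cuDNN backend needs a GPU and cuDNN build with block-scaled matmul support, so the calls are indicative rather than portable:

import jax
import jax.numpy as jnp
from jax import nn

lhs = jax.random.normal(jax.random.key(0), (64, 128), dtype=jnp.bfloat16)
rhs = jax.random.normal(jax.random.key(1), (128, 32), dtype=jnp.bfloat16)
dn = (((1,), (0,)), ((), ()))   # standard dot_general dimension numbers

# configs=None selects the mxfp8 defaults built above; implementation=None
# currently resolves to 'cudnn'.
out = nn.scaled_dot_general(lhs, rhs, dn, preferred_element_type=jnp.bfloat16)

# Explicitly requesting the cuDNN backend:
out = nn.scaled_dot_general(lhs, rhs, dn, implementation='cudnn')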
@@ -30,9 +30,10 @@ from jax._src import core
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
 from jax._src.cudnn.scaled_matmul_stablehlo import (
-    generate_quantized_tensors,
-    mxfp8_configs,
     quantize,
     quantize_dequantize,
     shape_normalization,
+    BlockScaleConfig,
 )
 from jax.test_util import check_grads
 from jax import nn
@@ -42,9 +43,9 @@ import jax.numpy as jnp
 
 config.parse_flags_with_absl()
 
-def _is_required_cudnn_version_satisfied(min_cudnn_version):
+def _is_required_cudnn_version_satisfied(min_cc, min_cudnn_version):
   return (
-      jtu.is_cuda_compute_capability_at_least("8.0") and
+      jtu.is_cuda_compute_capability_at_least(min_cc) and
       cuda_versions is not None and
       cuda_versions.cudnn_get_version() >= min_cudnn_version
   )
@@ -56,6 +57,50 @@ def _check_cudnn_backend(fn, *args, **kwargs):
 
 _cudnn_dbias_error = 'cuDNN only supports bias gradient'
 
+def _generate_quantized_tensors(
+    batch, lhs_non_contract, contract, rhs_non_contract,
+    configs, dtype=jnp.float32,
+):
+  cast_to_representable = partial(
+      quantize_dequantize,
+      scale=jnp.ones((1,)),
+      compute_dtype=dtype,
+  )
+
+  k1, k2 = jax.random.split(jax.random.key(123), 2)
+
+  a = cast_to_representable(
+      jax.random.uniform(
+          k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
+      ),
+      configs[0].data_type,
+  )
+  b = cast_to_representable(
+      jax.random.uniform(
+          k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
+      ),
+      configs[1].data_type,
+  )
+
+  dn = ((2,), (0,))
+  a_3d = shape_normalization(a, dn)
+  b_3d = shape_normalization(b, dn)
+  a_q, a_scales = quantize(a, configs[0])
+  b_q, b_scales = quantize(b, configs[1])
+
+  return a, b, a_q, b_q, a_scales, b_scales
+
+_create_mxfp8_config = lambda: BlockScaleConfig(
+    mode='mxfp8',
+    block_size=32,
+    data_type=jnp.float8_e4m3fn,
+    scale_type=jnp.float8_e8m0fnu,
+    global_scale=None,
+    infer_only=False
+)
+
+mxfp8_configs = [_create_mxfp8_config() for _ in range(3)]
+
 @jtu.with_config(jax_legacy_prng_key="allow",
                  jax_numpy_dtype_promotion="standard")
 class NNFunctionsTest(jtu.JaxTestCase):
@@ -64,10 +109,13 @@ class NNFunctionsTest(jtu.JaxTestCase):
       lhs_non_contract=[240, 100],
       dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
       configs=[mxfp8_configs,],
+      impl=['cudnn',],
   )
-  def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs):
+  def testScaledMatmul(self, contract, lhs_non_contract, dtype, configs, impl):
+    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
+      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
     batch, rhs_non_contract = 4, 256
-    a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
+    a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors(
        batch, lhs_non_contract, contract, rhs_non_contract,
        configs, dtype=dtype,
     )
@@ -83,10 +131,12 @@ class NNFunctionsTest(jtu.JaxTestCase):
       is_training=[True, False],
       output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
       configs=[mxfp8_configs,],
+      impl=['cudnn',],
   )
   def testScaledDotGeneral(
-      self, is_training, output_type, configs,
-  ):
+      self, is_training, output_type, configs, impl):
+    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
+      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
     cast_to_representable = partial(
        quantize_dequantize,
        scale=jnp.ones((1,)),
@@ -132,7 +182,6 @@ class NNFunctionsTest(jtu.JaxTestCase):
     out = j_inference(a, b)
     out_ref = j_inference_ref(a, b)
     self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
 
   @parameterized.product(
       dtype=[jnp.bfloat16, jnp.float16],
       group_num=[1, 2, 4],
@@ -140,7 +189,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
       impl=['cudnn', 'xla'],
   )
   def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
-    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
+    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@@ -189,7 +238,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
     if isinstance(mask_mode, str):
       mask_mode = (mask_mode,)
     min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
-    if not _is_required_cudnn_version_satisfied(min_cudnn_version):
+    if not _is_required_cudnn_version_satisfied("8.0", min_cudnn_version):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
 
     dtype = jnp.bfloat16
@@ -252,7 +301,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
       use_vmap=[False, True],
   )
   def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
-    if not _is_required_cudnn_version_satisfied(8904):
+    if not _is_required_cudnn_version_satisfied("8.0", 8904):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
 
     dtype = jnp.bfloat16
@@ -13,14 +13,15 @@ from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
 from jax._src.cudnn.scaled_matmul_stablehlo import (
     scaled_matmul_wrapper,
     scaled_dot_general_wrapper,
-    mxfp8_configs,
-    generate_quantized_tensors,
     shape_normalization,
     quantize,
     quantize_dequantize,
+    BlockScaleConfig,
 )
 
 
 config.parse_flags_with_absl()
-input_sharding_configs = [
+input_shardings = [
     (("dp", None, "tp"), ("dp", None, "tp")),
     (("dp", None, "tp"), ("dp", None, None)),
     (("dp", None, "tp"), ("dp", "tp", None)),
@@ -41,7 +42,63 @@ expected_hlos = [
     ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name),
     ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name),
 ]
-sharding_configs = [[i, j] for i, j in zip(input_sharding_configs, expected_hlos)]
+expected_output_spec = [
+    PartitionSpec('dp',),
+    PartitionSpec('dp',),
+    PartitionSpec('dp', None, 'tp'),
+    PartitionSpec('dp', None, 'tp'),
+    PartitionSpec('dp', 'tp', None),
+    PartitionSpec(None, 'dp', 'tp'),
+    PartitionSpec(None, 'tp', None),
+    PartitionSpec(None, None, 'tp'),
+]
+sharding_configs = {
+    input_sharding: (hlo, output_spec)
+    for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
+}
+
+mxfp8_config = BlockScaleConfig(
+    mode='mxfp8',
+    block_size=32,
+    data_type=jnp.float8_e4m3fn,
+    scale_type=jnp.float8_e8m0fnu,
+    global_scale=None,
+    infer_only=False
+)
+mxfp8_configs = [mxfp8_config for _ in range(3)]
+
+def generate_quantized_tensors(
+    batch, lhs_non_contract, contract, rhs_non_contract,
+    configs, dtype=jnp.float32,
+):
+  cast_to_representable = partial(
+      quantize_dequantize,
+      scale=jnp.ones((1,)),
+      compute_dtype=dtype,
+  )
+
+  k1, k2 = jax.random.split(jax.random.key(123), 2)
+
+  a = cast_to_representable(
+      jax.random.uniform(
+          k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
+      ),
+      configs[0].data_type,
+  )
+  b = cast_to_representable(
+      jax.random.uniform(
+          k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
+      ),
+      configs[1].data_type,
+  )
+
+  dn = ((2,), (0,))
+  a_3d = shape_normalization(a, dn)
+  b_3d = shape_normalization(b, dn)
+  a_q, a_scales = quantize(a, configs[0])
+  b_q, b_scales = quantize(b, configs[1])
+
+  return a, b, a_q, b_q, a_scales, b_scales
 
 
 def shard_and_device_put(
@@ -76,7 +133,10 @@ def shard_and_device_put(
   return a, b, in_shardings
 
 
-def get_hlo_text(in_shardings, block_scale_configs=mxfp8_configs):
+def get_hlo_text(in_shardings, block_scale_configs=None):
+  if block_scale_configs is None:
+    block_scale_configs = mxfp8_configs
+
   mesh_names = ("dp", "tp")
   devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
   mesh = Mesh(devices, mesh_names)
@@ -109,15 +169,15 @@ class ScaledMatmulTest(jtu.JaxTestCase):
       self.skipTest("Requires at least Blackwell arch")
 
   @jtu.sample_product(
-      sharding_config=sharding_configs,
+      in_shardings=sharding_configs,
   )
   @jtu.run_on_devices("cuda")
-  def test_collectives(self, sharding_config):
+  def test_collectives(self, in_shardings):
     if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
       self.skipTest("Partition Test enabled for at least 4 GPUs")
 
-    input_sharding, expected_hlo = sharding_config[0], sharding_config[1]
-    hlo_text = get_hlo_text(input_sharding)
+    expected_hlo = sharding_configs[in_shardings][0]
+    hlo_text = get_hlo_text(in_shardings)
 
     hlo_pattern = re.compile(
         r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
@@ -171,22 +231,21 @@ class ScaledMatmulTest(jtu.JaxTestCase):
        )
 
   @jtu.sample_product(
-      sharding_config=sharding_configs,
+      in_shardings=sharding_configs,
       block_scale_configs=[mxfp8_configs,],
   )
   @jtu.run_on_devices("cuda")
-  def test_scaled_matmul_sharded(self, sharding_config, block_scale_configs):
+  def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
     if len(jax.local_devices()) < 4:
       self.skipTest("Require at least 4 devices to run sharding tests.")
     batch, contract, non_contract = 2, 1024, 256
 
     a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
        batch, non_contract, contract, non_contract, block_scale_configs,
     )
 
     devices = np.array(jax.local_devices()[:4])
     devices = devices.reshape((2, 2))
-    in_shardings = sharding_config[0]
+    expected_output_spec = sharding_configs[in_shardings][1]
 
     with Mesh(devices, ("dp", "tp")) as mesh:
       a_q, b_q, a_scales, b_scales, input_shardings = (
@@ -205,11 +264,11 @@ class ScaledMatmulTest(jtu.JaxTestCase):
       j_scaled_matmul = jax.jit(
          scaled_matmul_wrapper, in_shardings=input_shardings
       )
-      hlo_text = j_scaled_matmul.lower(*args).compile().as_text()
+      hlo_compiled = j_scaled_matmul.lower(*args).compile()
       hlo_pattern = re.compile(
          r".*".join([re.escape(x) for x in ("custom-call", c_name)])
       )
-      self.assertRegex(hlo_text, hlo_pattern)
+      self.assertRegex(hlo_compiled.as_text(), hlo_pattern)
 
       j_ref = jax.jit(
          partial(
@@ -221,8 +280,13 @@ class ScaledMatmulTest(jtu.JaxTestCase):
 
       out = j_scaled_matmul(*args)
       out_ref = j_ref(a, b)
+      expected_output_sharding = NamedSharding(
+          mesh=mesh, spec=expected_output_spec
+      )
       self.assertArraysAllClose(out, out_ref, rtol=1e-3, atol=1e-3)
 
+      self.assertTrue(
+          out.sharding.is_equivalent_to(expected_output_sharding, out.ndim)
+      )
 
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
@@ -306,9 +370,9 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
     out_ref = j_inference_ref(a, b)
     self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
 
-  @jtu.sample_product(sharding_config=input_sharding_configs)
+  @jtu.sample_product(in_shardings=sharding_configs)
   @jtu.run_on_devices("cuda")
-  def test_dot_general_sharded(self, sharding_config):
+  def test_dot_general_sharded(self, in_shardings):
     if len(jax.local_devices()) < 4:
       self.skipTest("Require at least 4 devices to run sharding tests.")
 
@@ -344,7 +408,6 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
       # custom scaled_matmul op.
       return jnp.sum(jnp.tanh(y))
 
-    in_shardings = sharding_config
     devices = np.array(jax.local_devices()[:4])
     devices = devices.reshape((2, 2))
     with Mesh(devices, ("dp", "tp")) as mesh:
@@ -375,6 +438,7 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
     self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
     self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
 
+
   @jtu.sample_product(
       configs=[
           ((1, 128, 256), (1, 128, 256), (0, 0, 0)),