Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

Commit 681ee18436 (parent d9456f36b3): Fix CI
@@ -28,10 +28,6 @@ import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import config
from jax._src.cudnn.scaled_matmul_stablehlo import (
    scaled_matmul_wrapper as cudnn_scaled_matmul,
    scaled_dot_general_wrapper as cudnn_scaled_dot_general,
    BlockScaleConfig)
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
@@ -40,6 +36,10 @@ from jax._src.core import AxisName
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
    scaled_matmul_wrapper as cudnn_scaled_matmul,
    scaled_dot_general_wrapper as cudnn_scaled_dot_general,
    BlockScaleConfig)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
@@ -104,7 +104,7 @@ def create_mxfp8_configs_if_available():
        global_scale=None,
        infer_only=False
    )


  return [_create_mxfp8_config() for _ in range(3)]
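Note on the two import hunks above: they move the cuDNN scaled-matmul wrappers (scaled_matmul_wrapper, scaled_dot_general_wrapper, BlockScaleConfig) below the fused-attention import, and the -104,7 hunk sits inside the helper that builds the mxfp8 block-scale configs. As a hedged, illustrative sketch only (the helper name make_mxfp8_configs is not from the commit; the field values are copied from the diff), such a config list is assembled like this:

  import jax.numpy as jnp
  from jax._src.cudnn.scaled_matmul_stablehlo import BlockScaleConfig

  def make_mxfp8_configs():
    # The diff builds three of these; presumably one per operand.
    config = BlockScaleConfig(
        mode='mxfp8',
        block_size=32,                  # one scale per 32-element block
        data_type=jnp.float8_e4m3fn,    # quantized operand dtype
        scale_type=jnp.float8_e8m0fnu,  # per-block scale dtype (needs ml_dtypes >= 0.5.0)
        global_scale=None,
        infer_only=False,
    )
    return [config for _ in range(3)]

The hunks that follow are from a test module changed by the same commit.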
@@ -8,6 +8,7 @@ import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax._src import config
from jax._src import dtypes as _dtypes
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.scaled_matmul_stablehlo import (
@@ -19,7 +20,6 @@ from jax._src.cudnn.scaled_matmul_stablehlo import (
    BlockScaleConfig,
)


config.parse_flags_with_absl()
input_shardings = [
    (("dp", None, "tp"), ("dp", None, "tp")),
@@ -57,22 +57,6 @@ sharding_configs = {
    for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
}

def create_mxfp8_configs_if_available():
  if _dtypes.float8_e8m0fnu is None:
    raise unittest.SkipTest("float8_e8m0fnu is not available.")

  def _create_mxfp8_config():
    return BlockScaleConfig(
        mode='mxfp8',
        block_size=32,
        data_type=jnp.float8_e4m3fn,
        scale_type=jnp.float8_e8m0fnu,
        global_scale=None,
        infer_only=False
    )

  return [_create_mxfp8_config() for _ in range(3)]


def generate_quantized_tensors(
    batch, lhs_non_contract, contract, rhs_non_contract,
@@ -139,11 +123,22 @@ def shard_and_device_put(

  return a, b, in_shardings

def create_mxfp8_configs():
  if _dtypes.float8_e8m0fnu is None:
    return None

def get_hlo_text(in_shardings, block_scale_configs=None):
  if block_scale_configs is None:
    block_scale_configs = create_mxfp8_configs_if_available()
  mxfp8_config = BlockScaleConfig(
      mode='mxfp8',
      block_size=32,
      data_type=jnp.float8_e4m3fn,
      scale_type=jnp.float8_e8m0fnu,
      global_scale=None,
      infer_only=False
  )

  return [mxfp8_config for _ in range(3)]

def get_hlo_text(in_shardings, block_scale_configs):
  mesh_names = ("dp", "tp")
  devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
  mesh = Mesh(devices, mesh_names)
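A hedged reading of the hunk above: create_mxfp8_configs now returns None when ml_dtypes does not provide float8_e8m0fnu (instead of a helper raising unittest.SkipTest), and get_hlo_text takes the block-scale configs as an explicit argument. A minimal sketch of the resulting call pattern inside this test module; the sharding tuple is copied from input_shardings purely as a placeholder:

  block_scale_configs = create_mxfp8_configs()
  if block_scale_configs is None:
    raise unittest.SkipTest("float8_e8m0fnu is not available.")
  hlo_text = get_hlo_text(
      (("dp", None, "tp"), ("dp", None, "tp")), block_scale_configs)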
@@ -170,21 +165,26 @@ class ScaledMatmulTest(jtu.JaxTestCase):
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if _dtypes.float8_e8m0fnu is None:
      self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
    if cudnn_version < 90700:
      self.skipTest("Requires >= cuDNN 9.7.0")
    if not jtu.is_cuda_compute_capability_at_least("10.0"):
      self.skipTest("Requires at least Blackwell arch")

  mxfp8_configs = create_mxfp8_configs()

  @jtu.sample_product(
      in_shardings=sharding_configs,
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_collectives(self, in_shardings):
  def test_collectives(self, in_shardings, block_scale_configs):
    if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
      self.skipTest("Partition Test enabled for at least 4 GPUs")

    expected_hlo = sharding_configs[in_shardings][0]
    hlo_text = get_hlo_text(in_shardings)
    hlo_text = get_hlo_text(in_shardings, block_scale_configs)

    hlo_pattern = re.compile(
        r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
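The assertion built in the hunk above joins the expected HLO snippets with ".*" and searches the lowered text with re.DOTALL, so the snippets only have to appear in the given order somewhere in the HLO. A self-contained sketch of that matching logic (the strings below are placeholders, not real HLO):

  import re

  expected_hlo = ["all-gather", "dot", "reduce-scatter"]
  hlo_text = "foo all-gather(...) bar dot(...) baz reduce-scatter(...)"
  hlo_pattern = re.compile(
      r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
  )
  assert hlo_pattern.search(hlo_text) is not None  # snippets present, in order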
@@ -197,7 +197,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
      block_scale_configs=[create_mxfp8_configs_if_available(),],
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul(
@@ -239,7 +239,7 @@ class ScaledMatmulTest(jtu.JaxTestCase):

  @jtu.sample_product(
      in_shardings=sharding_configs,
      block_scale_configs=[create_mxfp8_configs_if_available(),],
      block_scale_configs=[mxfp8_configs,],
  )
  @jtu.run_on_devices("cuda")
  def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
@@ -298,8 +298,6 @@ class ScaledMatmulTest(jtu.JaxTestCase):
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):

  block_scale_configs = create_mxfp8_configs_if_available()

  def setUp(self):
    super().setUp()
    try:
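A hedged note on how this hunk relates to the "Fix CI" commit message: the removed class attribute called create_mxfp8_configs_if_available() while the class body was being evaluated, so on machines without float8_e8m0fnu the unittest.SkipTest raised by that helper escaped at import/collection time instead of marking individual tests as skipped; the replacement in the next hunk stores the result of the None-returning create_mxfp8_configs() and leaves skipping to setUp. A minimal, self-contained reproduction of that failure mode (BrokenTest and helper are illustrative names, not from the commit):

  import unittest

  def helper():
    raise unittest.SkipTest("dtype not available")

  class BrokenTest(unittest.TestCase):
    value = helper()  # runs at class-definition (import) time and raises

    def test_value(self):
      self.assertIsNone(self.value)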
@@ -307,11 +305,15 @@ class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if _dtypes.float8_e8m0fnu is None:
      self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
    if cudnn_version < 90700:
      self.skipTest("Requires >= cuDNN 9.7.0")
    if not jtu.is_cuda_compute_capability_at_least("10.0"):
      self.skipTest("Requires at least Blackwell arch")

  block_scale_configs = create_mxfp8_configs()

  @jtu.sample_product(
      configs=[
          # a_shape, b_shape, dimension_numbers, is_training