# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test different parameterizations of a matmul."""
import os
import unittest
from absl.testing import absltest, parameterized
from jax._src import config
from jax._src import test_util as jtu
import jax.numpy as jnp
try:
# We only import this to see if Mosaic is available.
import jax.experimental.mosaic.gpu # noqa: F401
except ImportError:
matmul = None
else:
from jax.experimental.mosaic.gpu.examples import matmul
try:
import hypothesis as hp
import hypothesis.strategies as hps
except (ModuleNotFoundError, ImportError):
raise unittest.SkipTest("these tests require hypothesis")
config.parse_flags_with_absl()
jtu.setup_hypothesis()
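# Appending --xla_gpu_autotune_level=0 disables XLA's GPU autotuning for this
# test process.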
os.environ["XLA_FLAGS"] = (
os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0")
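

# seed_hypothesis forwards the shard index supplied by parameterized as the
# hypothesis seed, so each shard explores a different sequence of draws.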
def seed_hypothesis(f):
  def wrapper(self, seed):
    return hp.seed(seed)(f)(self)

  return wrapper


@jtu.with_config(jax_traceback_filtering="off")
class MatmulTestCase(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if matmul is None:
      self.skipTest("Mosaic GPU not available.")
    if (not jtu.test_device_matches(["cuda"]) or
        not jtu.is_cuda_compute_capability_equal("9.0")):
      self.skipTest("Only works on GPU with capability sm90a")
  @parameterized.named_parameters(
      (f"_shard{i}", i) for i in range(5)
  )
  @seed_hypothesis
  @hp.settings(max_examples=100)  # Add verbosity=hp.Verbosity.verbose to debug
  @hp.given(hps.data())
  def test_matmul(self, data):
    in_dtype = data.draw(
        hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]),
        label="in_dtype",
    )
    out_dtype = jnp.float32
    if in_dtype != jnp.float32:
      out_dtype = data.draw(
          hps.sampled_from([in_dtype, jnp.float32]),
          label="out_dtype",
      )
    bytewidth = jnp.dtype(in_dtype).itemsize
    m, n, k = (
        data.draw(hps.sampled_from([128, 256, 512, 2048]), label=d)
        for d in "mnk"
    )
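    # Kernel parameters: pipeline depth (stages), swizzle width, and the
    # output tile handled by each block.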
    stages = data.draw(hps.integers(2, 5), label="stages")
    swizzle = data.draw(hps.sampled_from([32, 64, 128]), label="swizzle")
    tile_m = data.draw(
        hps.sampled_from([t for t in [64, 128, 256] if t <= m]), label="tile_m"
    )
    tile_n = data.draw(
        hps.sampled_from([t for t in [64, 128, 256] if t <= n]), label="tile_n"
    )
    grid_m, grid_n = m // tile_m, n // tile_n
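    # Only keep draws where the block grid divides evenly into the requested
    # grid tiling and cluster shape.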
    grid_tile_n = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_n")
    hp.assume(grid_n % grid_tile_n == 0)
    cluster_m = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_m")
    hp.assume(grid_m % cluster_m == 0)
    cluster_n = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_n")
    hp.assume(grid_n % cluster_n == 0)
    # TODO(apaszke): Non-portable clusters (16 blocks) sometimes deadlock.
    hp.assume(cluster_m * cluster_n <= 8)
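    # 32-bit inputs are only exercised with a transposed RHS; for 16-bit
    # inputs both layouts are sampled.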
    if bytewidth == 4:
      rhs_transpose = True
    else:
      rhs_transpose = data.draw(hps.booleans(), label="rhs_transpose")
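    # Configurations that do not fit in shared memory are rejected as invalid
    # draws (hp.assume) rather than reported as failures.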
    try:
      matmul.verify(
          m,
          k,
          n,
          stages=stages,
          tile_m=tile_m,
          tile_n=tile_n,
          in_dtype=in_dtype,
          out_dtype=out_dtype,
          cluster_m=cluster_m,
          cluster_n=cluster_n,
          grid_tile_n=grid_tile_n,
          swizzle=swizzle,
          rhs_transpose=rhs_transpose,
      )
    except ValueError as e:
      if "Mosaic GPU kernel exceeds available shared memory" in str(e):
        hp.assume(False)
      raise e


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())