mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Simplify the matmul example
Remove a bunch of WGMMAImpl classes. This is meant to be a simple forkable example, not a complete kernel.

PiperOrigin-RevId: 655923069
This commit is contained in:
parent be9cc807d9
commit e59303cf3e
@@ -15,8 +15,8 @@
"""Matmul kernels for H100."""

import dataclasses
import enum
import functools
from typing import Any

import jax
from jax import random
@@ -47,30 +47,19 @@ class Tiling:
  n: int
  k: int

  @property
  def mk(self):
    return (self.m, self.k)

  @property
  def kn(self):
    return (self.k, self.n)

  @property
  def nk(self):
    return (self.n, self.k)

  @property
  def mn(self):
    return (self.m, self.n)


class F32Precision(enum.Enum):
  DEFAULT = enum.auto()
  TF32_X3 = enum.auto()
  # Allow access by .mk, .kn, .mn, etc.
  def __getattr__(self, name):
    if len(name) == 1:
      return super().__getattribute__(name)
    return tuple(getattr(self, d) for d in name)


class WGMMADefaultImpl:
  """Default WGMMA implementation."""
  """Default WGMMA implementation.

  The kernel can accept any class that satisfies the same interface as this
  class.
  """

  @staticmethod
  def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator:
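The `__getattr__` added to the Tiling dataclass replaces the removed mk/kn/nk/mn properties: any multi-letter attribute expands into a tuple of the named dimensions. A minimal, self-contained sketch of the pattern (an illustration, not the kernel code; the field names and frozen dataclass are taken from the example):

import dataclasses

@dataclasses.dataclass(frozen=True)
class Tiling:
  m: int
  n: int
  k: int

  # Multi-letter names expand into tuples; single letters fall back to the
  # normal attribute lookup (and still raise AttributeError when missing).
  def __getattr__(self, name):
    if len(name) == 1:
      return super().__getattribute__(name)
    return tuple(getattr(self, d) for d in name)

t = Tiling(m=128, n=128, k=64)
assert t.mk == (128, 64)   # lhs tile shape
assert t.kn == (64, 128)   # rhs tile shape
assert t.mn == (128, 128)  # output tile shape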
@@ -83,144 +72,32 @@ class WGMMADefaultImpl:
      lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
      rhs_transpose: WGMMALayout,
  ) -> dict[str, jax.ShapeDtypeStruct]:
    del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose
    return {}
    del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose  # Unused.
    return ()

  @staticmethod
  def get_result_tile(acc: WGMMAAccumulator) -> FragmentedArray:
  def get_result(acc: WGMMAAccumulator) -> FragmentedArray:
    return acc.value

  @staticmethod
  def wgmma(
      smem_scratch: dict[str, SmemRef],  # pylint: disable=unused-argument
      smem_scratch: Any,  # pylint: disable=unused-argument
      acc: WGMMAAccumulator,
      b_order: WGMMALayout,
      a_slice: SmemRef,
      b_slice: SmemRef,
  ) -> dict[str, WGMMAAccumulator]:
    """Perform a matrix multiplication.

    This function must guarantee that all WGMMA operations queued before it was
    called have completed before returning.
    """
    acc = wgmma(acc, a_slice, b_slice, b_order=b_order)
    nvvm.wgmma_commit_group_sync_aligned()
    nvvm.wgmma_wait_group_sync_aligned(1)
    return acc


class WGMMATF32x3Impl:
  """WGMMA implementation for 3xTF32 precision."""

  @staticmethod
  def zero_accs(tile_m, tile_n) -> dict[str, WGMMAAccumulator]:
    zero_acc = WGMMADefaultImpl.zero_accs(tile_m, tile_n)
    return {"main": zero_acc, "errs": zero_acc}

  @staticmethod
  def smem_shape_extra(
      block_tiling: Tiling,
      tma_tiling: Tiling,
      lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
      rhs_transpose: bool,
  ) -> dict[str, jax.ShapeDtypeStruct]:
    del rhs_transpose
    lhs_err = jax.ShapeDtypeStruct(shape=tile_shape(block_tiling.mk, tma_tiling.mk), dtype=lhs_dtype)
    rhs_err = jax.ShapeDtypeStruct(shape=tile_shape(block_tiling.kn, tma_tiling.kn), dtype=rhs_dtype)
    return {"lhs_err": lhs_err, "rhs_err": rhs_err}

  @staticmethod
  def get_result_tile(accs) -> FragmentedArray:
    return accs["main"].value + accs["errs"].value

  @staticmethod
  def rounding_error(x_ref, err_ref):
    """Store the TF32 rounding error of x_ref in err_ref."""
    f32 = ir.F32Type.get()
    i32 = ir.IntegerType.get_signless(32)
    t = FragmentedArray.load_strided(x_ref)
    tf32_mask = FragmentedArray.splat(c(0xFFFFE000, i32), t.shape, t.layout)
    t_tf32 = (t.bitcast(i32) & tf32_mask).bitcast(f32)
    (t - t_tf32).store_untiled(err_ref)

  @staticmethod
  def wgmma(
      smem_scratch: dict[str, SmemRef],
      accs: dict[str, WGMMAAccumulator],
      b_order: WGMMALayout,
      a_slice: SmemRef,
      b_slice: SmemRef,
  ) -> dict[str, WGMMAAccumulator]:
    acc = wgmma(accs["main"], a_slice, b_slice, b_order=b_order)
    nvvm.wgmma_commit_group_sync_aligned()
    # Note: we assert that only the slice_ab and err_b mmas are still running
    # which are unaffected by writing to the err_a shared memory.
    # After nvvm.wgmma_wait_group_sync_aligned(2) there are no wgmmas
    # accessing err_a so we can safely write to it.
    nvvm.wgmma_wait_group_sync_aligned(2)
    WGMMATF32x3Impl.rounding_error(a_slice, smem_scratch["lhs_err"])
    commit_shared()
    acc_err = wgmma(accs["errs"], smem_scratch["lhs_err"], b_slice, b_order=b_order)
    nvvm.wgmma_commit_group_sync_aligned()
    # Note: similar to the above we wait for the last wgmma access to
    # err_b which was 2 wgmmas ago.
    nvvm.wgmma_wait_group_sync_aligned(2)
    WGMMATF32x3Impl.rounding_error(b_slice, smem_scratch["rhs_err"])
    commit_shared()
    acc_err = wgmma(acc_err, a_slice, smem_scratch["rhs_err"], b_order=b_order)
    nvvm.wgmma_commit_group_sync_aligned()
    nvvm.wgmma_wait_group_sync_aligned(2)
    return {"main": acc, "errs": acc_err}

class WGMMACvtRhsImpl:
  """Mixed WGMMA implementation where B is converted to A."""

  @staticmethod
  def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator:
    return WGMMADefaultImpl.zero_accs(tile_m, tile_n)

  @staticmethod
  def smem_shape_extra(
      block_tiling: Tiling,
      tma_tiling: Tiling,
      lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
      rhs_transpose: bool,
  ) -> dict[str, jax.ShapeDtypeStruct]:
    del rhs_dtype
    if rhs_transpose:
      raise NotImplementedError("Transpose requires more elaborate handling of tiling.")

    if tma_tiling.k != 64:
      raise ValueError(f"WGMMA layout needs the left tiling dimension to be 64 {tma_tiling.k=}")

    # The second dim needs to be tma_tiling.k so it is 128b wide and
    # the first dim needs to line up with the lhs dimension. That's
    # why we have a strange (k, k) here.
    cvt_shape = tile_shape(block_tiling.kn, (tma_tiling.k, tma_tiling.k))
    return {"cvt": jax.ShapeDtypeStruct(shape=cvt_shape, dtype=lhs_dtype)}

  @staticmethod
  def get_result_tile(acc: WGMMAAccumulator) -> FragmentedArray:
    return WGMMADefaultImpl.get_result_tile(acc)

  @staticmethod
  def wgmma(
      smem_scratch: dict[str, SmemRef],  # pylint: disable=unused-argument
      acc: WGMMAAccumulator,
      b_order: WGMMALayout,
      a_slice: SmemRef,
      b_slice: SmemRef,
  ) -> dict[str, WGMMAAccumulator]:
    # Convert the load
    arr = FragmentedArray.load_tiled(b_slice, swizzle=128)
    cvt_ty = ir.MemRefType(smem_scratch["cvt"].type)
    # TODO(cperivol): https://research.google/blog/mixed-input-matrix-multiplication-performance-optimizations/
    arr = arr.astype(cvt_ty.element_type)
    # Make sure no wgmma is running.
    # TODO(cperivol): double buffer.
    nvvm.wgmma_wait_group_sync_aligned(0)
    arr.store_tiled(smem_scratch["cvt"], swizzle=128)
    commit_shared()
    acc = wgmma(acc, a_slice, smem_scratch["cvt"], b_order=b_order)
    nvvm.wgmma_commit_group_sync_aligned()
    return acc
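The removed WGMMACvtRhsImpl handled mixed-input matmuls by converting each RHS tile to the LHS dtype in shared memory and then issuing an ordinary same-dtype WGMMA. Roughly, in plain JAX terms (a hedged sketch, not the kernel code):

import jax.numpy as jnp

def mixed_matmul(a_bf16, b_s8):
  # Per tile, the removed impl loaded B, converted it with astype to the LHS
  # element type, stored it into the "cvt" SMEM buffer, and ran a bf16 x bf16 WGMMA.
  return a_bf16 @ b_s8.astype(a_bf16.dtype)

a = jnp.ones((128, 64), jnp.bfloat16)
b = jnp.ones((64, 128), jnp.int8)
out = mixed_matmul(a, b)  # bf16 x bf16 product of the converted operands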


def mlir_context(f):
  def wrap(*args, **kw):
    with mlir.make_ir_context(), ir.Location.unknown():
@@ -281,21 +158,16 @@ def build_kernel(

  c = arith.ConstantOp.create_index

  compute_scratch_shapes = {
      "lhs": jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype),
      "rhs": jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.kn, tma_tiling.kn)), rhs_dtype),
  }
  compute_scratch_shapes |= wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose)

  epilogue_scratch_shapes = {
      "acc": jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype),
  }

  smem_shape = mosaic_gpu.Union(
      [compute_scratch_shapes, epilogue_scratch_shapes])
  compute_scratch_shape = (
      jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype),
      jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.kn, tma_tiling.kn)), rhs_dtype),
      wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose),
  )
  epilogue_scratch_shape = jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype)
  smem_shape = mosaic_gpu.Union([compute_scratch_shape, epilogue_scratch_shape])
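The scratch declaration changes from a dict keyed by name to a flat tuple of jax.ShapeDtypeStruct entries (double-buffered lhs tile, double-buffered rhs tile, impl-specific extras), which the kernel body below unpacks positionally as (lhs_smem, rhs_smem, impl_smem). A hedged illustration with hypothetical tile shapes (the real shapes come from tile_shape and the TMA tiling):

import jax
import jax.numpy as jnp

stages = 4
lhs_tile = (2, 1, 64, 64)   # hypothetical tiled (block_m, block_k) layout
rhs_tile = (1, 2, 64, 64)   # hypothetical tiled (block_k, block_n) layout
compute_scratch_shape = (
    jax.ShapeDtypeStruct((stages, *lhs_tile), jnp.float16),
    jax.ShapeDtypeStruct((stages, *rhs_tile), jnp.float16),
    (),  # WGMMADefaultImpl.smem_shape_extra(...) returns an empty tuple
)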

  def _main(ctx, a_device, b_device, c_device, smem):
    (compute_smem, epilogue_smem), barriers = smem
    ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), barriers = smem

    memref.assume_alignment(c_device, 16)
@@ -315,7 +187,7 @@ def build_kernel(
      barrier.arrive_expect_tx(txcount)
      ctx.async_copy(
          src_ref=a_device,
          dst_ref=memref_slice(compute_smem["lhs"], slot),
          dst_ref=memref_slice(lhs_smem, slot),
          gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)),
          gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
          **common_copy_args,
@@ -328,7 +200,7 @@
      assert tma_tiling.n == tma_tiling.k, block_tiling  # No need to flip the tiling.
      ctx.async_copy(
          src_ref=b_device,
          dst_ref=memref_slice(compute_smem["rhs"], slot),
          dst_ref=memref_slice(rhs_smem, slot),
          gmem_slice=rhs_slice,
          gmem_transform=rhs_transform,
          **common_copy_args,
@@ -348,13 +220,12 @@
      barriers[si].wait()

      with ctx.named_region("WGMMA"):
        a_slice = memref_slice(compute_smem["lhs"], si)
        b_slice = memref_slice(compute_smem["rhs"], si)
        a_slice = memref_slice(lhs_smem, si)
        b_slice = memref_slice(rhs_smem, si)
        rhs_smem_order = (
            WGMMALayout.COL_MAJOR if rhs_transpose else WGMMALayout.ROW_MAJOR
        )
        accs = wgmma_impl.wgmma(
            compute_smem, accs, rhs_smem_order, a_slice, b_slice)
        accs = wgmma_impl.wgmma(impl_smem, accs, rhs_smem_order, a_slice, b_slice)

      with ctx.named_region("TMA start"):
        tma_ki = arith.addi(ki, c(stages - 1))
@@ -370,19 +241,18 @@
      return accs

    # Wait until everyone is done with their WMMA
    # Wait until WGMMA is complete and we can safely read the accumulator.
    with ctx.named_region("WGMMA drain"):
      nvvm.wgmma_wait_group_sync_aligned(0)

    with ctx.named_region("SMEM store"):
      acc_val = wgmma_impl.get_result_tile(stage_loop_body.result)
      acc_smem = epilogue_smem["acc"]
      acc_val.store_tiled(acc_smem, swizzle=128)
      gpu.barrier()
      acc_val = wgmma_impl.get_result(stage_loop_body.result)
      acc_val.store_tiled(epilogue_smem, swizzle=128)
      commit_shared()  # Make sure the stores are visible to TMA.

    with ctx.named_region("GMEM store"):
      ctx.async_copy(
          src_ref=acc_smem,
          src_ref=epilogue_smem,
          dst_ref=c_device,
          gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)),
          gmem_transform=mosaic_gpu.TileTransform(out_tiling),
@@ -404,14 +274,6 @@
  )


def random_array(key, shape: tuple[int, ...], dtype: jnp.dtype):
  if jax.dtypes.issubdtype(dtype, np.floating):
    return random.uniform(key, shape, dtype=dtype)
  elif jax.dtypes.issubdtype(dtype, np.integer):
    return random.randint(key, shape, -127, 127, dtype)
  else:
    raise NotImplementedError(dtype)

def verify(
    m=(33 * 128),
    k=2048,
@@ -423,11 +285,7 @@ def verify(
    lhs_dtype=jnp.float16,
    rhs_dtype=jnp.float16,
    rhs_transpose=False,
    precision: F32Precision = F32Precision.DEFAULT,
):
  # TODO(cperivol): Transpose is only supported for 16bit wgmma. ATM
  # that means bf16 x bf16, f16 x f16 and bf16 x s8. When we get more
  # general mixed precision this check will need to be more nuanced.
  if not rhs_transpose and jnp.dtype(lhs_dtype).itemsize != 2:
    raise ValueError(
        "Implicit transpose can only happen for 16bit types (or mixed precision"
@@ -435,17 +293,10 @@ def verify(
    )

  kx, ky = random.split(random.key(1234))
  x = random_array(kx, (m, k), lhs_dtype)
  y = random_array(ky, (n, k) if rhs_transpose else (k, n), rhs_dtype)
  x = random.uniform(kx, (m, k), dtype=lhs_dtype)
  y = random.uniform(ky, (n, k) if rhs_transpose else (k, n), dtype=rhs_dtype)

  if lhs_dtype != rhs_dtype:
    impl = WGMMACvtRhsImpl
  else:
    match precision:
      case F32Precision.DEFAULT:
        impl = WGMMADefaultImpl
      case F32Precision.TF32_X3:
        impl = WGMMATF32x3Impl
  impl = WGMMADefaultImpl

  prof_spec = profiler.ProfilerSpec(4096) if profile else None
  f = build_kernel(
@@ -53,8 +53,9 @@ class MatmulTestCase(jtu.JaxTestCase):
      tile_m=(64, 128, 256),
      tile_n=(64, 128, 256),
      in_dtype=(jnp.float16, jnp.bfloat16),  # f32 tested separately
      rhs_transpose=(False, True),
  )
  def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype):
  def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype, rhs_transpose):
    if stages * (128 // jnp.dtype(in_dtype).itemsize) > k:
      self.skipTest("Too many stages.")
@@ -74,7 +75,7 @@ class MatmulTestCase(jtu.JaxTestCase):
          tile_n=tile_n,
          lhs_dtype=in_dtype,
          rhs_dtype=in_dtype,
          rhs_transpose=True,
          rhs_transpose=rhs_transpose,
      )
    except ValueError as e:
      if "Mosaic GPU kernel exceeds available shared memory" in str(e):
@@ -88,9 +89,8 @@
      stages=(2, 4),
      tile_m=(64, 128, 256),
      tile_n=(64, 128, 256),
      high_precision=(False, True),
  )
  def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n, high_precision):
  def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n):
    if stages * (128 // jnp.dtype(jnp.float32).itemsize) > k:
      self.skipTest("Too many stages.")
@@ -111,39 +111,12 @@
          lhs_dtype=jnp.float32,
          rhs_dtype=jnp.float32,
          rhs_transpose=True,
          precision=(
              matmul.F32Precision.TF32_X3
              if high_precision
              else matmul.F32Precision.DEFAULT
          ),
      )
    except ValueError as e:
      if "Mosaic GPU kernel exceeds available shared memory" in str(e):
        self.skipTest("Not enough shared memory for test, skipping.")
      raise e

  @parameterized.parameters(
      dict(m=55 * 128, n=95 * 128, k=48 * 128, stages=4, tile_m=128),
      dict(m=55 * 128, n=45 * 128, k=48 * 128, stages=4, tile_m=128),
      dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64),
      dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64),
  )
  def test_mixed_matmul(self, m, k, n, stages, tile_m):
    # RHS.element_size==1b so k_tile=128
    if stages * 128 > k:
      self.skipTest("Too many stages.")

    matmul.verify(
        m,
        k,
        n,
        stages,
        tile_m=tile_m,
        rhs_transpose=False,
        lhs_dtype=jnp.bfloat16,
        rhs_dtype=jnp.int8,
    )


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())