[Mosaic GPU] Simplify the matmul example

Remove a bunch of WGMMAImpl classes. This is meant to be a simple forkable example,
not a complete kernel.

PiperOrigin-RevId: 655923069
Adam Paszke 2024-07-25 05:43:11 -07:00 committed by jax authors
parent be9cc807d9
commit e59303cf3e
2 changed files with 45 additions and 221 deletions
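The kernel still accepts any implementation class that provides the same interface as the remaining WGMMADefaultImpl (see its new docstring in the diff below). A minimal sketch of that interface, with a hypothetical class name, method names and signatures taken from the diff, and bodies elided:

    class MyWGMMAImpl:
      @staticmethod
      def zero_accs(tile_m: int, tile_n: int):
        """Return the initial accumulator(s) for one output tile."""
        ...

      @staticmethod
      def smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose):
        """Return a pytree of jax.ShapeDtypeStruct for extra SMEM scratch (default: none)."""
        return ()

      @staticmethod
      def get_result(acc):
        """Turn the accumulator(s) into the FragmentedArray that gets stored out."""
        ...

      @staticmethod
      def wgmma(smem_scratch, acc, b_order, a_slice, b_slice):
        """Issue the WGMMAs for one LHS/RHS tile pair and return new accumulator(s).

        Must ensure all WGMMAs queued before this call have completed on return.
        """
        ...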


@@ -15,8 +15,8 @@
"""Matmul kernels for H100."""
import dataclasses
import enum
import functools
from typing import Any
import jax
from jax import random
@@ -47,30 +47,19 @@ class Tiling:
n: int
k: int
@property
def mk(self):
return (self.m, self.k)
@property
def kn(self):
return (self.k, self.n)
@property
def nk(self):
return (self.n, self.k)
@property
def mn(self):
return (self.m, self.n)
class F32Precision(enum.Enum):
DEFAULT = enum.auto()
TF32_X3 = enum.auto()
# Allow access by .mk, .kn, .mn, etc.
def __getattr__(self, name):
if len(name) == 1:
return super().__getattribute__(name)
return tuple(getattr(self, d) for d in name)
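# [Illustration, not part of the diff] With the __getattr__ above, any string
# of dimension names yields the corresponding shape tuple, e.g.:
#   t = Tiling(m=128, n=256, k=64)
#   t.mk == (128, 64); t.kn == (64, 256); t.mn == (128, 256)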
class WGMMADefaultImpl:
"""Default WGMMA implementation."""
"""Default WGMMA implementation.
The kernel can accept any class that satisfies the same interface as this
class.
"""
@staticmethod
def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator:
@@ -83,144 +72,32 @@ class WGMMADefaultImpl:
lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
rhs_transpose: WGMMALayout,
) -> dict[str, jax.ShapeDtypeStruct]:
del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose
return {}
del block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose # Unused.
return ()
@staticmethod
def get_result_tile(acc: WGMMAAccumulator) -> FragmentedArray:
def get_result(acc: WGMMAAccumulator) -> FragmentedArray:
return acc.value
@staticmethod
def wgmma(
smem_scratch: dict[str, SmemRef], # pylint: disable=unused-argument
smem_scratch: Any, # pylint: disable=unused-argument
acc: WGMMAAccumulator,
b_order: WGMMALayout,
a_slice: SmemRef,
b_slice: SmemRef,
) -> dict[str, WGMMAAccumulator]:
"""Perform a matrix multiplication.
This function must guarantee that all WGMMA operations queued before it was
called have completed before returning.
"""
acc = wgmma(acc, a_slice, b_slice, b_order=b_order)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(1)
return acc
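# [Annotation, not part of the diff] wgmma_wait_group_sync_aligned(1) above
# returns once at most one WGMMA commit group is still in flight, so every
# group committed before this call has finished, which satisfies the
# docstring's guarantee while still overlapping with the group committed here.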
class WGMMATF32x3Impl:
"""WGMMA implementation for 3xTF32 precision."""
@staticmethod
def zero_accs(tile_m, tile_n) -> dict[str, WGMMAAccumulator]:
zero_acc = WGMMADefaultImpl.zero_accs(tile_m, tile_n)
return {"main": zero_acc, "errs": zero_acc}
@staticmethod
def smem_shape_extra(
block_tiling: Tiling,
tma_tiling: Tiling,
lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
rhs_transpose: bool,
) -> dict[str, jax.ShapeDtypeStruct]:
del rhs_transpose
lhs_err = jax.ShapeDtypeStruct(shape=tile_shape(block_tiling.mk, tma_tiling.mk), dtype=lhs_dtype)
rhs_err = jax.ShapeDtypeStruct(shape=tile_shape(block_tiling.kn, tma_tiling.kn), dtype=rhs_dtype)
return {"lhs_err": lhs_err, "rhs_err": rhs_err}
@staticmethod
def get_result_tile(accs) -> FragmentedArray:
return accs["main"].value + accs["errs"].value
@staticmethod
def rounding_error(x_ref, err_ref):
"""Store the TF32 rounding error of x_ref in err_ref."""
f32 = ir.F32Type.get()
i32 = ir.IntegerType.get_signless(32)
t = FragmentedArray.load_strided(x_ref)
tf32_mask = FragmentedArray.splat(c(0xFFFFE000, i32), t.shape, t.layout)
t_tf32 = (t.bitcast(i32) & tf32_mask).bitcast(f32)
(t - t_tf32).store_untiled(err_ref)
@staticmethod
def wgmma(
smem_scratch: dict[str, SmemRef],
accs: dict[str, WGMMAAccumulator],
b_order: WGMMALayout,
a_slice: SmemRef,
b_slice: SmemRef,
) -> dict[str, WGMMAAccumulator]:
acc = wgmma(accs["main"], a_slice, b_slice, b_order=b_order)
nvvm.wgmma_commit_group_sync_aligned()
# Note: at this point only the current a @ b WGMMA and the previous call's
# a @ rhs_err WGMMA can still be in flight, and neither reads the lhs_err
# buffer. After nvvm.wgmma_wait_group_sync_aligned(2) no in-flight WGMMA
# accesses lhs_err, so we can safely overwrite it.
nvvm.wgmma_wait_group_sync_aligned(2)
WGMMATF32x3Impl.rounding_error(a_slice, smem_scratch["lhs_err"])
commit_shared()
acc_err = wgmma(accs["errs"], smem_scratch["lhs_err"], b_slice, b_order=b_order)
nvvm.wgmma_commit_group_sync_aligned()
# Note: similar to the above, we wait until the last WGMMA that read rhs_err
# (committed two groups ago, in the previous call) has completed before
# overwriting it.
nvvm.wgmma_wait_group_sync_aligned(2)
WGMMATF32x3Impl.rounding_error(b_slice, smem_scratch["rhs_err"])
commit_shared()
acc_err = wgmma(acc_err, a_slice, smem_scratch["rhs_err"], b_order=b_order)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(2)
return {"main": acc, "errs": acc_err}
class WGMMACvtRhsImpl:
"""Mixed WGMMA implementation where B is converted to A."""
@staticmethod
def zero_accs(tile_m: int, tile_n: int) -> WGMMAAccumulator:
return WGMMADefaultImpl.zero_accs(tile_m, tile_n)
@staticmethod
def smem_shape_extra(
block_tiling: Tiling,
tma_tiling: Tiling,
lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype,
rhs_transpose: bool,
) -> dict[str, jax.ShapeDtypeStruct]:
del rhs_dtype
if rhs_transpose:
raise NotImplementedError("Transpose requires more elaborate handling of tiling.")
if tma_tiling.k != 64:
raise ValueError(f"WGMMA layout needs the left tiling dimension to be 64 {tma_tiling.k=}")
# The second dim needs to be tma_tiling.k so that each tile row is 128 bytes
# wide (matching the swizzle), and the first dim needs to line up with the
# lhs tiling dimension. That's why we have a strange (k, k) here.
cvt_shape = tile_shape(block_tiling.kn, (tma_tiling.k, tma_tiling.k))
return {"cvt": jax.ShapeDtypeStruct(shape=cvt_shape, dtype=lhs_dtype)}
@staticmethod
def get_result_tile(acc: WGMMAAccumulator) -> FragmentedArray:
return WGMMADefaultImpl.get_result_tile(acc)
@staticmethod
def wgmma(
smem_scratch: dict[str, SmemRef], # pylint: disable=unused-argument
acc: WGMMAAccumulator,
b_order: WGMMALayout,
a_slice: SmemRef,
b_slice: SmemRef,
) -> dict[str, WGMMAAccumulator]:
# Load the RHS tile into registers and convert it to the LHS element type.
arr = FragmentedArray.load_tiled(b_slice, swizzle=128)
cvt_ty = ir.MemRefType(smem_scratch["cvt"].type)
# TODO(cperivol): https://research.google/blog/mixed-input-matrix-multiplication-performance-optimizations/
arr = arr.astype(cvt_ty.element_type)
# Make sure no wgmma is running.
# TODO(cperivol): double buffer.
nvvm.wgmma_wait_group_sync_aligned(0)
arr.store_tiled(smem_scratch["cvt"], swizzle=128)
commit_shared()
acc = wgmma(acc, a_slice, smem_scratch["cvt"], b_order=b_order)
nvvm.wgmma_commit_group_sync_aligned()
return acc
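# [Annotation, not part of the diff] This mixed-input path converts B to A's
# element type before the multiply: the RHS tile is loaded into registers,
# cast to lhs_dtype, staged through the "cvt" SMEM buffer, and only then fed
# to a regular matching-dtype WGMMA.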
def mlir_context(f):
def wrap(*args, **kw):
with mlir.make_ir_context(), ir.Location.unknown():
@@ -281,21 +158,16 @@ def build_kernel(
c = arith.ConstantOp.create_index
compute_scratch_shapes = {
"lhs": jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype),
"rhs": jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.kn, tma_tiling.kn)), rhs_dtype),
}
compute_scratch_shapes |= wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose)
epilogue_scratch_shapes = {
"acc": jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype),
}
smem_shape = mosaic_gpu.Union(
[compute_scratch_shapes, epilogue_scratch_shapes])
compute_scratch_shape = (
jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.mk, tma_tiling.mk)), lhs_dtype),
jax.ShapeDtypeStruct((stages, *tile_shape(block_tiling.kn, tma_tiling.kn)), rhs_dtype),
wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose),
)
epilogue_scratch_shape = jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype)
smem_shape = mosaic_gpu.Union([compute_scratch_shape, epilogue_scratch_shape])
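# [Annotation, not part of the diff] The Union lets the epilogue output tile
# share SMEM with the pipeline's lhs/rhs/impl buffers, which is safe because
# the accumulator is only spilled to SMEM after the stage loop has drained.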
def _main(ctx, a_device, b_device, c_device, smem):
(compute_smem, epilogue_smem), barriers = smem
((lhs_smem, rhs_smem, impl_smem), epilogue_smem), barriers = smem
memref.assume_alignment(c_device, 16)
@@ -315,7 +187,7 @@ def build_kernel(
barrier.arrive_expect_tx(txcount)
ctx.async_copy(
src_ref=a_device,
dst_ref=memref_slice(compute_smem["lhs"], slot),
dst_ref=memref_slice(lhs_smem, slot),
gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)),
gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk),
**common_copy_args,
@@ -328,7 +200,7 @@ def build_kernel(
assert tma_tiling.n == tma_tiling.k, block_tiling # No need to flip the tiling.
ctx.async_copy(
src_ref=b_device,
dst_ref=memref_slice(compute_smem["rhs"], slot),
dst_ref=memref_slice(rhs_smem, slot),
gmem_slice=rhs_slice,
gmem_transform=rhs_transform,
**common_copy_args,
@@ -348,13 +220,12 @@ def build_kernel(
barriers[si].wait()
with ctx.named_region("WGMMA"):
a_slice = memref_slice(compute_smem["lhs"], si)
b_slice = memref_slice(compute_smem["rhs"], si)
a_slice = memref_slice(lhs_smem, si)
b_slice = memref_slice(rhs_smem, si)
rhs_smem_order = (
WGMMALayout.COL_MAJOR if rhs_transpose else WGMMALayout.ROW_MAJOR
)
accs = wgmma_impl.wgmma(
compute_smem, accs, rhs_smem_order, a_slice, b_slice)
accs = wgmma_impl.wgmma(impl_smem, accs, rhs_smem_order, a_slice, b_slice)
with ctx.named_region("TMA start"):
tma_ki = arith.addi(ki, c(stages - 1))
@@ -370,19 +241,18 @@ def build_kernel(
return accs
# Wait until everyone is done with their WMMA
# Wait until WGMMA is complete and we can safely read the accumulator.
with ctx.named_region("WGMMA drain"):
nvvm.wgmma_wait_group_sync_aligned(0)
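# [Annotation, not part of the diff] Waiting on group count 0 means no WGMMA
# commit groups remain in flight, so the accumulator registers read below are
# final.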
with ctx.named_region("SMEM store"):
acc_val = wgmma_impl.get_result_tile(stage_loop_body.result)
acc_smem = epilogue_smem["acc"]
acc_val.store_tiled(acc_smem, swizzle=128)
gpu.barrier()
acc_val = wgmma_impl.get_result(stage_loop_body.result)
acc_val.store_tiled(epilogue_smem, swizzle=128)
commit_shared() # Make sure the stores are visible to TMA.
with ctx.named_region("GMEM store"):
ctx.async_copy(
src_ref=acc_smem,
src_ref=epilogue_smem,
dst_ref=c_device,
gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)),
gmem_transform=mosaic_gpu.TileTransform(out_tiling),
@@ -404,14 +274,6 @@ def build_kernel(
)
def random_array(key, shape: tuple[int, ...], dtype: jnp.dtype):
if jax.dtypes.issubdtype(dtype, np.floating):
return random.uniform(key, shape, dtype=dtype)
elif jax.dtypes.issubdtype(dtype, np.integer):
return random.randint(key, shape, -127, 127, dtype)
else:
raise NotImplementedError(dtype)
def verify(
m=(33 * 128),
k=2048,
@@ -423,11 +285,7 @@ def verify(
lhs_dtype=jnp.float16,
rhs_dtype=jnp.float16,
rhs_transpose=False,
precision: F32Precision = F32Precision.DEFAULT,
):
# TODO(cperivol): Transpose is only supported for 16bit wgmma. ATM
# that means bf16 x bf16, f16 x f16 and bf16 x s8. When we get more
# general mixed precision this check will need to be more nuanced.
if not rhs_transpose and jnp.dtype(lhs_dtype).itemsize != 2:
raise ValueError(
"Implicit transpose can only happen for 16bit types (or mixed precision"
@@ -435,17 +293,10 @@ def verify(
)
kx, ky = random.split(random.key(1234))
x = random_array(kx, (m, k), lhs_dtype)
y = random_array(ky, (n, k) if rhs_transpose else (k, n), rhs_dtype)
x = random.uniform(kx, (m, k), dtype=lhs_dtype)
y = random.uniform(ky, (n, k) if rhs_transpose else (k, n), dtype=rhs_dtype)
if lhs_dtype != rhs_dtype:
impl = WGMMACvtRhsImpl
else:
match precision:
case F32Precision.DEFAULT:
impl = WGMMADefaultImpl
case F32Precision.TF32_X3:
impl = WGMMATF32x3Impl
impl = WGMMADefaultImpl
prof_spec = profiler.ProfilerSpec(4096) if profile else None
f = build_kernel(


@@ -53,8 +53,9 @@ class MatmulTestCase(jtu.JaxTestCase):
tile_m=(64, 128, 256),
tile_n=(64, 128, 256),
in_dtype=(jnp.float16, jnp.bfloat16), # f32 tested separately
rhs_transpose=(False, True),
)
def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype):
def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype, rhs_transpose):
if stages * (128 // jnp.dtype(in_dtype).itemsize) > k:
self.skipTest("Too many stages.")
@@ -74,7 +75,7 @@ class MatmulTestCase(jtu.JaxTestCase):
tile_n=tile_n,
lhs_dtype=in_dtype,
rhs_dtype=in_dtype,
rhs_transpose=True,
rhs_transpose=rhs_transpose,
)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" in str(e):
@@ -88,9 +89,8 @@ class MatmulTestCase(jtu.JaxTestCase):
stages=(2, 4),
tile_m=(64, 128, 256),
tile_n=(64, 128, 256),
high_precision=(False, True),
)
def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n, high_precision):
def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n):
if stages * (128 // jnp.dtype(jnp.float32).itemsize) > k:
self.skipTest("Too many stages.")
@@ -111,39 +111,12 @@ class MatmulTestCase(jtu.JaxTestCase):
lhs_dtype=jnp.float32,
rhs_dtype=jnp.float32,
rhs_transpose=True,
precision=(
matmul.F32Precision.TF32_X3
if high_precision
else matmul.F32Precision.DEFAULT
),
)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" in str(e):
self.skipTest("Not enough shared memory for test, skipping.")
raise e
@parameterized.parameters(
dict(m=55 * 128, n=95 * 128, k=48 * 128, stages=4, tile_m=128),
dict(m=55 * 128, n=45 * 128, k=48 * 128, stages=4, tile_m=128),
dict(m=64, n=95 * 128, k=48 * 128, stages=4, tile_m=64),
dict(m=64, n=45 * 128, k=48 * 128, stages=4, tile_m=64),
)
def test_mixed_matmul(self, m, k, n, stages, tile_m):
# RHS.element_size==1b so k_tile=128
if stages * 128 > k:
self.skipTest("Too many stages.")
matmul.verify(
m,
k,
n,
stages,
tile_m=tile_m,
rhs_transpose=False,
lhs_dtype=jnp.bfloat16,
rhs_dtype=jnp.int8,
)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())