rocm_jax/tests/mosaic/gpu_test.py
Adam Paszke 30a9e1b3bf [Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, the TensorCore only updates a
contiguous chunk of columns in TMEM, so after executing two of those we
end up with a TMEM layout that looks like this:

```
Contributing CTA |    0    |    1    |    0    |    1    |
N local          |   0:128 |   0:128 | 128:256 | 128:256 |
N                |   0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer cover the output columns
monotonically all the way up to N=512; instead, they contain a number of jumps.

We could fix this on the load side by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that seems like more
trouble than it's worth (and is not well supported by TMA unless we
increase the number of striding levels).

Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.
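
A minimal sketch (not part of this file) of the resulting TMEM-column-to-logical-N
mapping, assuming the n=512, `.cta_group::2` case from the table above; the
128/256/512 constants are specific to this example:

```python
def logical_n(tmem_col: int) -> int:
  """Maps a TMEM column (0..511) to the logical output column it holds."""
  mma_idx = tmem_col // 256      # which of the two collective N=256 MMAs wrote it
  cta = (tmem_col % 256) // 128  # which CTA's RHS slice contributed it (table row 1)
  within = tmem_col % 128
  return cta * 256 + mma_idx * 128 + within

# The four 128-column TMEM chunks start at logical N 0, 256, 128 and 384,
# matching the table above.
assert [logical_n(c) for c in (0, 128, 256, 384)] == [0, 256, 128, 384]
```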

PiperOrigin-RevId: 735791426
2025-03-11 09:53:20 -07:00

# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU DSL functions and utilities."""
from collections.abc import Sequence
import dataclasses
import enum
import itertools
import math
import operator
import os
import re
import unittest
from absl.testing import absltest, parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import fragmented_array as fa
from jax.experimental.mosaic.gpu import tcgen05
import jax.numpy as jnp
import numpy as np
try:
import jax._src.lib.mosaic_gpu # noqa: F401
HAS_MOSAIC_GPU = True
except ImportError:
HAS_MOSAIC_GPU = False
class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok
x = 0
y = 1
z = 2
else:
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import core
from jax.experimental.mosaic.gpu import launch_context
from jax.experimental.mosaic.gpu import utils as utils
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
Dimension = gpu.Dimension
# ruff: noqa: F405
# pylint: disable=g-complex-comprehension
config.parse_flags_with_absl()
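# Emits a nest of `fori` loops, one per entry of `bounds`, and calls `body`
# with one index value per bound.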
def nd_loop(bounds, body, *, _idxs = ()):
if not bounds:
body(*_idxs)
return
bound, *other_bounds = bounds
@fori(bound, ())
def _loop_body(i, _):
nd_loop(other_bounds, body, _idxs=(*_idxs, i))
return ()
def mlir_sum(elems):
assert elems
total = elems[0]
for elem in elems[1:]:
total = arith.addi(total, elem)
return total
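# Copies src into dst element by element from a single thread. If `swizzle` is
# set, the destination indices are permuted with the 128-byte XOR swizzle
# pattern, simulating a swizzled TMA transfer.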
def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
index = ir.IndexType.get()
thread_id = gpu.thread_id(gpu.Dimension.x)
stride = gpu.block_dim(gpu.Dimension.x)
for dim in (gpu.Dimension.y, gpu.Dimension.z):
thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride))
stride = arith.muli(stride, gpu.block_dim(dim))
is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index))
src_ty = ir.MemRefType(src.type)
dst_ty = ir.MemRefType(dst.type)
if src_ty.shape != dst_ty.shape:
raise ValueError(
f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}"
)
shape = src_ty.shape
if src_ty.element_type != dst_ty.element_type:
raise ValueError(
f"src and dst element types don't match: {src_ty.element_type} !="
f" {dst_ty.element_type}"
)
contig_strides = get_contiguous_strides(shape)
# If swizzling is on, at least one of the memrefs must be contiguous
# (simulating a TMA).
if (swizzle is not None and
src_ty.get_strides_and_offset()[0] != contig_strides and
dst_ty.get_strides_and_offset()[0] != contig_strides):
raise NotImplementedError(src_ty, dst_ty)
bw = bitwidth(src_ty.element_type)
if bw < 8:
assert bw.bit_count() == 1
packing = 8 // bw
if shape[-1] % packing:
raise NotImplementedError
workgroup_mem = ir.Attribute.parse("#gpu.address_space<workgroup>")
shape = (*shape[:-1], shape[-1] // packing)
contig_strides = get_contiguous_strides(shape)
def bitcast(ref):
ref_ty = ir.MemRefType(ref.type)
old_strides = ref_ty.get_strides_and_offset()[0]
if old_strides[-1] != 1:
raise NotImplementedError
new_strides = [s // packing for s in old_strides[:-1]] + [1]
new_ref_ty = ir.MemRefType.get(
shape,
ir.VectorType.get((packing,), src_ty.element_type), # noqa: F821
ir.StridedLayoutAttr.get(0, new_strides),
ref_ty.memory_space,
)
ptr_space = (
3
if ref_ty.memory_space is not None
and ref_ty.memory_space == workgroup_mem
else None
)
return ptr_as_memref(
# NOTE: memref_ptr applies the offset in case there was any.
memref_ptr(ref, memory_space=ptr_space),
new_ref_ty,
ptr_memory_space=ptr_space,
)
src = bitcast(src)
dst = bitcast(dst)
bw = 8
del src_ty, dst_ty # If you remove this, update it in the branch above
dyn_strides = [c(s, index) for s in contig_strides]
with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block):
def body(*idx):
dst_idx = idx
if swizzle is not None:
assert swizzle.bit_count() == 1
assert bw % 8 == 0
bytes_per_element = bw // 8
linear_idx = c(0, index)
for stride, i in zip(dyn_strides, idx):
linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
# Swizzle pattern repeats every 128 bytes.
swizzle_src = arith.remui(
arith.divui(linear_idx, c(128 // bytes_per_element, index)),
c(swizzle // 16, index),
)
# Swizzle happens in groups of 16 bytes.
swizzle_shift = 4 - (bytes_per_element.bit_length() - 1)
dst_linear_idx = arith.xori(
linear_idx, arith.shli(swizzle_src, c(swizzle_shift, index))
)
dst_idx = [
arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index))
for stride, bound in zip(dyn_strides, shape)
]
memref.store(memref.load(src, idx), dst, dst_idx)
nd_loop([c(d, index) for d in shape], body)
scf.yield_([])
gpu.barrier()
nvvm.fence_proxy(nvvm.ProxyKind.async_)
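# Builds an m x n iota (value = row * n + col) as a FragmentedArray in the
# WGMMA accumulator layout: each warp owns a 16-row strip, lane // 4 selects
# the row within an 8-row subtile, (lane % 4) * 2 selects the starting column,
# and each register holds 2 adjacent columns.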
def iota_tensor(m, n, dtype: jax.typing.DTypeLike, tiled_layout=False):
assert m % 64 == 0
assert n % 8 == 0
def c(i):
return arith.constant(index, ir.IntegerAttr.get(index, i))
index = ir.IndexType.get()
i32 = ir.IntegerType.get_signless(32)
warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32))
within_warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32))
warp_row_start = arith.muli(warp_id, c(16))
within_warp_row = arith.divui(within_warp_id, c(4))
start_row = arith.addi(warp_row_start, within_warp_row)
start_col = arith.muli(arith.remui(within_warp_id, c(4)), c(2))
registers = np.empty((m // 64, n // 8, 2, 1), dtype=object)
for row_tile, col_tile, row_subtile, _ in np.ndindex(registers.shape):
row = arith.addi(start_row, c(row_tile * 64 + row_subtile * 8))
col = arith.addi(start_col, c(col_tile * 8))
row_value_base = arith.muli(row, c(n))
vec = llvm.mlir_undef(ir.VectorType.get((2,), i32))
for col_offset in range(2):
value = arith.addi(row_value_base, arith.addi(c(col_offset), col))
value = arith.index_cast(i32, value)
vec = vector.insertelement(value, vec, position=c(col_offset))
registers[row_tile, col_tile, row_subtile, 0] = vec
t = mgpu.FragmentedArray(
_registers=registers, _layout=mgpu.WGMMA_LAYOUT, _is_signed=True
)
if tiled_layout:
t = t.to_layout(mgpu.TILED_LAYOUT_WGMMA)
return t.astype(
utils.dtype_to_ir_type(dtype), is_signed=utils.is_signed(dtype)
)
class TestCase(parameterized.TestCase):
def setUp(self):
if not HAS_MOSAIC_GPU:
self.skipTest("jaxlib built without Mosaic GPU")
if (not jtu.test_device_matches(["cuda"]) or
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("Only works on GPU with capability >= sm90")
super().setUp()
self.prng = np.random.default_rng(1234)
self.context = mlir.make_ir_context()
if mgpu_dialect is not None:
mgpu_dialect.register_dialect(self.context)
self.enter_context(config.traceback_filtering("off"))
self.enter_context(self.context)
self.enter_context(ir.Location.unknown())
class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest):
def setUp(self):
self.skip_unless_sm90a()
super().setUp()
class TestUtilTest(TestCase):
def test_copy_basic(self):
def kernel(ctx, src, dst, _):
copy(src, dst)
x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3)
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
np.testing.assert_array_equal(y, x)
def test_copy_swizzle(self):
def kernel(ctx, src, dst, _):
copy(src, dst, swizzle=128)
x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
expected = np.zeros_like(y)
for i in range(8):
for j in range(8):
js = j ^ i
expected[i, (j * 4):(j * 4) + 4] = x[i, (js * 4):(js * 4) + 4]
np.testing.assert_array_equal(y, expected)
def test_copy_swizzle_noop(self):
# Two swizzles cancel out
def kernel(ctx, src, dst, smem):
copy(src, smem, swizzle=128)
copy(smem, dst, swizzle=128)
x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)
@parameterized.parameters(False, True)
def test_iota_tensor(self, tiled_layout):
m = n = 64
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
index = ir.IndexType.get()
registers = iota_tensor(m, n, jnp.float32, tiled_layout).registers
assert registers.size == 16, registers.size
for i, vec_reg in enumerate(registers.flat):
for j in range(2):
reg = vector.extractelement(vec_reg, position=c(j, index))
memref.store(
reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)]
)
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
regs = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
thread_ids = np.arange(128)
warp_ids = thread_ids // 32
lane_ids = thread_ids % 32
thread_rows = warp_ids * 16 + lane_ids // 4
thread_start_cols = (lane_ids % 4) * 2
thread_cols = thread_start_cols[:, None] + (np.arange(n // 8)[None] * 8)
regs = regs.reshape(128, 8, 2, 2)
for row_half in range(2):
for col_half in range(2):
np.testing.assert_array_equal(
regs[..., row_half, col_half],
(thread_rows[:, None] + row_half * 8) * n + thread_cols + col_half
)
class MemRefTest(TestCase):
@parameterized.product(
dim=tuple(range(3)),
strided=(False, True)
)
def test_unsqueeze(self, dim, strided):
def kernel(ctx, inp, out, _):
if strided:
for i in range(8):
s = ds(i, 1)
out_slice = s if dim != 0 else (slice(None), s)
copy(
memref_unsqueeze(memref_slice(inp, s), dim),
memref_slice(out, out_slice),
)
else:
copy(memref_unsqueeze(inp, dim), out)
x = np.arange(8 * 16, dtype=jnp.float32).reshape(8, 16)
out_shape = list(x.shape)
out_shape.insert(dim, 1)
out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
)(x)
np.testing.assert_array_equal(y, x.reshape(out_shape))
@parameterized.product(
dim=tuple(range(2)),
strided=(False, True)
)
def test_unfold(self, dim, strided):
in_shape = (8, 16)
def kernel(ctx, inp, out, _):
if strided:
# We slice the dim we don't unfold
for i in range(in_shape[1 - dim] // 4):
s = ds(i * 4, 4)
in_slice = s if dim == 1 else (slice(None), s)
out_slice = s if dim == 1 else (slice(None),) * 3 + (s,)
copy(
memref_unfold(memref_slice(inp, in_slice), dim, (2, 2, None)),
memref_slice(out, out_slice),
)
else:
copy(memref_unfold(inp, dim, (2, 2, None)), out)
x = np.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
out_shape = list(in_shape)
out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4]
out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
)(x)
np.testing.assert_array_equal(y, x.reshape(out_ty.shape))
@parameterized.product(
dim=tuple(range(2)),
)
def test_fold_not_strided(self, dim):
def kernel(ctx, inp, out, _):
copy(memref_fold(inp, dim, 2), out)
x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8)
out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
)(x)
np.testing.assert_array_equal(y, x.reshape(out_ty.shape))
@parameterized.named_parameters(
("drop_1s", (1, 1, 5, 1, 1, 2, 1, 1), (5, 1, 2)),
("add_1s", (5, 1, 2), (1, 1, 5, 1, 1, 2, 1, 1)),
("fold", (1, 5, 2, 1,), (1, 10, 1)),
("un", (1, 10, 1), (1, 5, 2, 1,)),
)
def test_reshape(self, inp_shape, out_shape):
def kernel(ctx, inp, out, _):
copy(memref_reshape(inp, out_shape), out)
x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(*inp_shape)
out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
)(x)
np.testing.assert_array_equal(y, x.reshape(*out_shape))
@parameterized.named_parameters([
("packed", (4, 4, 4), (16, 4, 1), 1, 2, False),
("strided_end", (4, 4, 4, 4), (256, 64, 16, 4), 1, 2, False),
("strided_bot", (4, 4, 4, 4), (256, 16, 4, 1), 1, 2, False),
("strided_top", (4, 4, 4, 4), (256, 64, 4, 1), 1, 2, True),
("strided_mid", (4, 4, 4, 4), (265, 64, 16, 1), 1, 3, True),
# TODO(cperivol): Investigate why this is indexing OOB and uncomment.
# ("overap", (2, 4, 4), (16, 1, 1), 0, 3, True),
])
def test_fold_strided(
self, shape, strides, dim, fold_rank, throws_not_impl
):
expanded_shape = get_packed_shape(strides, shape)
total_size = np.prod(expanded_shape)
np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape)
index = tuple(slice(0, s) for s in shape)
# Reference implementation
def np_fold(inp, dim, fold_rank):
out_shape = list(inp.shape)
out_shape[dim : dim + fold_rank] = [
int(np.prod(inp.shape[dim : dim + fold_rank]))
]
if throws_not_impl:
return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype)
else:
return inp.reshape(*out_shape)
total_size = np.prod(shape) * np.prod(strides)
def do_test():
def kernel(ctx, inp, out, _):
copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out)
out = np_fold(np_inp[index], dim, fold_rank)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), np_inp, out, ()
)(np_inp)
assert (
not throws_not_impl
), "If it were going to throw, it would have thrown during the call above."
np.testing.assert_array_equal(y, out)
if throws_not_impl:
with self.assertRaises(NotImplementedError):
do_test()
else:
do_test()
@parameterized.parameters(jnp.uint64, jnp.uint32, jnp.uint16, jnp.uint8)
def test_scalar_argument(self, dtype):
if dtype == jnp.uint64 and not config.enable_x64.value:
self.skipTest(
"64-bit types are disabled: this leads to the input scalar being"
" traced as a uint32 value, which causes the top 32 bits of the 64-bit"
" values read from the 32-bit input buffer to sometimes"
" (nondeterministically) contain garbage.")
scalar = 42
expected = np.full((128, 128), scalar, dtype=dtype)
def kernel(ctx, inp, out, _):
del ctx
inp = memref.load(inp, [])
mgpu.FragmentedArray.splat(inp, expected.shape, is_signed=True).store_untiled(out)
res = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
jax.ShapeDtypeStruct(shape=(), dtype=expected.dtype),
expected,
(),
)(scalar)
np.testing.assert_array_equal(res, expected)
def get_packed_shape(strides, shape):
perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True)
ordered_strides = [strides[i] for i in perm]
ordered_shape = [shape[i] for i in perm]
packed_shape = [ordered_shape[-1]]
packed_shape += [
stride0 // stride
for stride0, stride in zip(ordered_strides, ordered_strides[1:])
]
# Invert permutation
inv_perm = [None] * len(perm)
for i, p in enumerate(perm):
inv_perm[p] = i
return [packed_shape[i] for i in inv_perm]
class WGMMALayoutTest(TestCase):
@parameterized.product(dtype=[jnp.float16, jnp.float32],
tiled_layout=[False, True],
transposed_smem=[False, True])
def test_store_untiled(self, dtype, tiled_layout, transposed_smem):
def kernel(ctx, out, _):
del ctx
if transposed_smem:
out = memref_transpose(out, (1, 0))
iota_tensor(64, 64, dtype, tiled_layout=tiled_layout).store_untiled(
out, vector_store=not transposed_smem
)
expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64)
if transposed_smem:
expected = expected.T
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.product(
dtype=[jnp.float32, jnp.float16, jnp.int8],
swizzle=(32, 64, 128),
num_col_tiles=(1, 2, 3),
tiled_layout=(False, True),
)
def test_store_tiled(self, dtype, swizzle, num_col_tiles, tiled_layout):
mlir_dtype = utils.dtype_to_ir_type(dtype)
if bytewidth(mlir_dtype) > 2 and swizzle == 32:
self.skipTest("Not implemented")
col_tiling = swizzle // bytewidth(mlir_dtype)
m = 128
n = col_tiling * num_col_tiles
tiling = (64, col_tiling)
def kernel(ctx, out, smem):
del ctx
iota_tensor(m, n, dtype, tiled_layout).store_tiled(smem, swizzle=swizzle)
copy(smem, out, swizzle=swizzle)
expected = (
np.arange(m * n, dtype=dtype)
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
.transpose(0, 2, 1, 3)
)
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), expected, expected
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.product(
dtype=[jnp.float16, jnp.int8],
swizzle=(32, 64, 128),
)
def test_store_tiled_short_n(self, dtype, swizzle):
mlir_dtype = utils.dtype_to_ir_type(dtype)
col_tiling = swizzle // bytewidth(mlir_dtype)
m = 128
n = 16 // bytewidth(mlir_dtype)
tiling = (64, col_tiling)
def kernel(ctx, out, smem):
iota_tensor(m, n, dtype).store_tiled(smem, swizzle=swizzle)
ctx.async_copy(
src_ref=smem,
dst_ref=out,
swizzle=swizzle,
gmem_slice=(ds(0, m), ds(0, col_tiling)),
gmem_transform=mgpu.TileTransform(tiling),
)
ctx.await_async_copy(0)
smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), dtype)
expected = np.arange(m * n, dtype=dtype).reshape(m, n)
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape
)()
np.testing.assert_array_equal(iota, expected)
@parameterized.named_parameters(
("bf16_i8", jnp.bfloat16, jnp.int8),
("i8_bf16", jnp.int8, jnp.bfloat16),
("i8_i8", jnp.int8, jnp.int8),
("i4_i4", jnp.int4, jnp.int4),
("i4_bf16", jnp.int4, jnp.bfloat16),
)
def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
m = 128
n = 256 * 8 // bitwidth(mlir_dtype_from)
def kernel(ctx, inp, out, smem):
del ctx
smem_from, smem_to = smem
copy(inp, smem_from, swizzle=128)
t = mgpu.FragmentedArray.load_tiled(
smem_from,
swizzle=128,
is_signed=utils.is_signed(jax_dtype_from),
layout=fa._tiled_wgmma_layout((m, n))
)
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
t.store_tiled(smem_to, swizzle=128)
copy(smem_to, out, swizzle=128)
from_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_from))
to_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_to))
# We only test lossless conversions for now.
# TODO(apaszke): Test and fix failures that appear with lossy conversions.
int_sample_dtype = getattr(
jnp,
"int" + str(min(bitwidth(mlir_dtype_from), bitwidth(mlir_dtype_to))),
)
sample_iinfo = jnp.iinfo(int_sample_dtype)
expected_raw = self.prng.integers(
low=sample_iinfo.min, high=sample_iinfo.max,
size=(m, n), dtype=np.int32
)
expected = lambda jax_dtype, tiling: expected_raw.reshape(
m // tiling[0], tiling[0], n // tiling[1], tiling[1]
).transpose(0, 2, 1, 3).astype(jax_dtype)
expected_from = expected(jax_dtype_from, from_tiling)
expected_to = expected(jax_dtype_to, to_tiling)
res = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
expected_from,
expected_to,
(expected_from, expected_to),
)(expected_from)
np.testing.assert_array_equal(res, expected_to)
@parameterized.named_parameters(
("f32", jnp.float32),
("f16", jnp.float16),
("i8", jnp.int8),
)
def test_load_tiled(self, jax_dtype):
mlir_dtype = utils.dtype_to_ir_type(jax_dtype)
m = 128
n = 256 // bytewidth(mlir_dtype)
tiling = (64, 128 // bytewidth(mlir_dtype))
def kernel(ctx, in_, out, smem):
del ctx
smem1, smem2 = smem
copy(in_, smem1, swizzle=128)
t = mgpu.FragmentedArray.load_tiled(
smem1, swizzle=128, is_signed=utils.is_signed(jax_dtype)
)
t.store_tiled(smem2, swizzle=128)
copy(smem2, out, swizzle=128)
expected = (
np.arange(m * n, dtype=jax_dtype)
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
.transpose(0, 2, 1, 3)
)
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2
)(expected)
np.testing.assert_array_equal(iota, expected)
class WGMMATest(TestCase):
def setUp(self):
super().setUp()
if not jtu.is_cuda_compute_capability_equal("9.0"):
self.skipTest("Only works on GPU with capability sm90a")
@parameterized.product(
lhs_transpose=(False, True),
rhs_transpose=(False, True),
in_mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type),
m=(64, 128, 192),
n=(64, 128, 192),
k_steps=(1, 2),
swizzle=(32, 64, 128),
jax_out_dtype=(jnp.float16, jnp.float32),
rhs_tiling_kind=("large", "small", "small+no_transpose"),
lhs_tiling_kind=("large", "small", "small+no_transpose"),
)
def test_wgmma_basic(
self,
m,
n,
k_steps,
in_mlir_dtype_cls,
lhs_transpose,
rhs_transpose,
swizzle,
jax_out_dtype,
rhs_tiling_kind,
lhs_tiling_kind,
):
if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type:
self.skipTest("Only f16 input is supported for f16 output.")
if swizzle != 128 and lhs_transpose and lhs_tiling_kind == "large":
self.skipTest("Transpose only supported in 128B swizzled WGMMA")
if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose:
self.skipTest("No transpose happening anyway")
if lhs_tiling_kind == "small+no_transpose" and not lhs_transpose:
self.skipTest("No transpose happening anyway")
in_mlir_dtype = in_mlir_dtype_cls.get()
out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype)
if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead
in_jax_dtype = jnp.float32
if lhs_transpose or not rhs_transpose:
self.skipTest("Transpose only supported in 16-bit WGMMA")
exponent_bits, mantissa_bits = 8, 10 # Use tf32
elif bytewidth(in_mlir_dtype) == 2:
if n % 64 != 0:
self.skipTest("16-bit WGMMA only supports n % 64 == 0")
if ir.F16Type.isinstance(in_mlir_dtype):
in_jax_dtype = jnp.float16
exponent_bits, mantissa_bits = 5, 10
elif ir.BF16Type.isinstance(in_mlir_dtype):
in_jax_dtype = jnp.bfloat16
exponent_bits, mantissa_bits = 8, 7
else:
raise NotImplementedError(in_mlir_dtype)
else:
raise NotImplementedError(in_mlir_dtype)
nk_tile = swizzle // bytewidth(in_mlir_dtype)
k = nk_tile * k_steps
assert m % 64 == 0 and n % nk_tile == 0
small_rhs_tile = rhs_tiling_kind != "large"
transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose"
rhs_tiling = (8, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
small_lhs_tile = lhs_tiling_kind != "large"
transpose_lhs_tiles = lhs_tiling_kind != "small+no_transpose"
lhs_tiling = (8, nk_tile) if small_lhs_tile else (64, nk_tile)
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers = scratch
lhs_transform = (mgpu.TileTransform(lhs_tiling),)
if lhs_transpose and transpose_lhs_tiles:
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
rhs_transform = (mgpu.TileTransform(rhs_tiling),)
if rhs_transpose and transpose_rhs_tiles:
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
ctx.async_copy(
src_ref=lhs,
dst_ref=lhs_smem,
swizzle=swizzle,
gmem_transform=lhs_transform,
barrier=barriers[0],
)
ctx.async_copy(
src_ref=rhs,
dst_ref=rhs_smem,
swizzle=swizzle,
gmem_transform=rhs_transform,
barrier=barriers[1],
)
for i in range(2):
barriers[i].wait()
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype)
if lhs_transpose:
perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2)
lhs_smem = memref_transpose(lhs_smem, perm)
if rhs_transpose:
perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2)
rhs_smem = memref_transpose(rhs_smem, perm)
acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
acc.value.store_untiled(out)
def quantize(x):
# Quantize the input to avoid rounding when feeding the WGMMA
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
x_shape = (k, m) if lhs_transpose else (m, k)
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype)
y_shape = (n, k) if rhs_transpose else (k, n)
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype)
if transpose_rhs_tiles:
rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling)
else:
rhs_smem_shape = tile_shape(y_shape, rhs_tiling)
if transpose_lhs_tiles:
lhs_tiling_t = lhs_tiling[::-1] if lhs_transpose else lhs_tiling
lhs_smem_shape = (m // lhs_tiling_t[0], k // lhs_tiling_t[1], *lhs_tiling)
else:
lhs_smem_shape = tile_shape(x_shape, lhs_tiling)
scratch_shape = [
jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype),
jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype),
mgpu.TMABarrier(2),
]
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
)(x, y)
x32, y32 = x.astype(np.float32), y.astype(np.float32)
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6
np.testing.assert_allclose(z, ref, atol=atol)
# TODO(apaszke): Add support for f32
@parameterized.product(
m=(64, 128, 192),
n=(64, 128, 192),
k_steps=(1, 2),
rhs_transpose=(False, True),
swizzle=(32, 64, 128),
dtype=[jnp.float16, jnp.bfloat16],
tiled_layout=(False, True),
)
def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, swizzle, dtype, tiled_layout):
index = ir.IndexType.get()
bytewidth = 2
nk_tile = swizzle // bytewidth
k = nk_tile * k_steps
def kernel(ctx, rhs, out, rhs_smem):
del ctx
for ki in range(k_steps):
for ni in range(n // nk_tile):
rhs_slice = (
ds(c(ki * nk_tile, index), nk_tile),
ds(c(ni * nk_tile, index), nk_tile),
)
if rhs_transpose:
rhs_slice = rhs_slice[::-1]
copy(
src=memref_slice(rhs, rhs_slice),
dst=memref_slice(rhs_smem, (ki, ni)),
swizzle=swizzle,
)
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
lhs_regs = iota_tensor(m, k, dtype, tiled_layout)
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
acc.value.store_untiled(out)
y_shape = (n, k) if rhs_transpose else (k, n)
y = self.prng.uniform(-1, 1, y_shape).astype(dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
scratch_shape = jax.ShapeDtypeStruct(
(k_steps, n // nk_tile, nk_tile, nk_tile), dtype
)
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape
)(y)
x = np.arange(m * k, dtype=dtype).reshape(m, k)
ref = jax.lax.dot(
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
)
rtol = 5e-4
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)
@parameterized.product(
rhs_transpose=(False, True),
swizzle=(32, 64, 128),
n=(8, 16),
small_rhs_tile=(False, True),
)
def test_narrow_n(self, rhs_transpose, swizzle, n, small_rhs_tile):
m, k_steps = 64, 2
bytewidth = 2
nk_tile = swizzle // bytewidth
k = nk_tile * k_steps
if small_rhs_tile and not rhs_transpose:
self.skipTest("Small tiles only supported for transposed RHS")
n_tile = 8 if small_rhs_tile else nk_tile
def kernel(ctx, rhs, out, smem):
rhs_smem, barrier = smem
gmem_slice = (ds(0, k), ds(0, max(n_tile, n)))
transform = (mgpu.TileTransform((n_tile, nk_tile)),)
if rhs_transpose:
gmem_slice = gmem_slice[::-1]
transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
ctx.async_copy(
src_ref=rhs,
dst_ref=rhs_smem,
swizzle=swizzle,
gmem_slice=gmem_slice,
gmem_transform=transform,
barrier=barrier,
)
barrier.wait()
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
lhs_regs = iota_tensor(m, k, jnp.float16)
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
if not small_rhs_tile:
smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
rhs_smem = memref_slice(rhs_smem, smem_slice)
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
acc.value.store_untiled(out)
jax_dtype = jnp.float16
y_shape = (n, k) if rhs_transpose else (k, n)
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
rhs_scratch_shape = jax.ShapeDtypeStruct(
(k_steps, (n + n_tile - 1) // n_tile, n_tile, nk_tile), jax_dtype
)
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()),
)(y)
x = np.arange(m * k, dtype=jax_dtype).reshape(m, k)
ref = jax.lax.dot(
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
)
np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0)
class TCGen05Test(TestCase):
def setUp(self):
super().setUp()
capabilities = ("10.0", "10.1")
if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities):
self.skipTest("Only works on GPU with capability sm_100a or sm_101a")
@parameterized.product(
lhs_transpose=(False, True),
rhs_transpose=(False, True),
in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32
out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation
m=(128,), # TODO(apaszke): 64, 192, 256
n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2
k_steps=(1, 2),
swizzle=(32, 64, 128,),
rhs_transpose_tiles=(False, True),
lhs_transpose_tiles=(False, True),
)
def test_mma_basic(
self,
m,
n,
k_steps,
swizzle,
lhs_transpose,
rhs_transpose,
in_jax_dtype,
out_jax_dtype,
rhs_transpose_tiles,
lhs_transpose_tiles,
):
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
self.skipTest("Only f16 input is supported for f16 output.")
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
k = swizzle_elems * k_steps
lhs_tiling = rhs_tiling = (8, swizzle_elems)
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers, acc = scratch
lhs_transform = (mgpu.TileTransform(lhs_tiling),)
if lhs_transpose_tiles:
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
rhs_transform = (mgpu.TileTransform(rhs_tiling),)
if rhs_transpose_tiles:
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
ctx.async_copy(
src_ref=lhs,
dst_ref=lhs_smem,
swizzle=swizzle,
gmem_transform=lhs_transform,
barrier=barriers[0],
)
ctx.async_copy(
src_ref=rhs,
dst_ref=rhs_smem,
swizzle=swizzle,
gmem_transform=rhs_transform,
barrier=barriers[1],
)
barriers[0].wait()
barriers[1].wait()
with mgpu.single_thread():
if lhs_transpose_tiles:
lhs_smem = memref_transpose(lhs_smem, (1, 0, 2, 3))
if lhs_transpose:
lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2))
if rhs_transpose_tiles:
rhs_smem = memref_transpose(rhs_smem, (1, 0, 2, 3))
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2))
tcgen05.mma(
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
)
tcgen05.commit_arrive(barriers[2])
barriers[2].wait(for_tensor_core=True)
acc[:].store_untiled(out)
in_finfo = jnp.finfo(in_jax_dtype)
exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant
def quantize(x):
# Quantize the input to avoid rounding when feeding the TensorCore
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
x_shape = (k, m) if lhs_transpose else (m, k)
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype)
y_shape = (n, k) if rhs_transpose else (k, n)
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
if rhs_transpose_tiles:
rhs_smem_shape = (
y_shape[1] // rhs_tiling[1], y_shape[0] // rhs_tiling[0], *rhs_tiling,
)
else:
rhs_smem_shape = tile_shape(y_shape, rhs_tiling)
if lhs_transpose_tiles:
lhs_smem_shape = (
x_shape[1] // lhs_tiling[1], x_shape[0] // lhs_tiling[0], *lhs_tiling,
)
else:
lhs_smem_shape = tile_shape(x_shape, lhs_tiling)
scratch_shape = [
jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype),
jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype),
mgpu.TMABarrier(3),
mgpu.TMEM((128, n), out_jax_dtype),
]
z = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
)(x, y)
x32, y32 = x.astype(np.float32), y.astype(np.float32)
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6
np.testing.assert_allclose(z, ref, atol=atol)
@parameterized.product(
lhs_transpose=(False, True),
rhs_transpose=(False, True),
in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32
out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation
m=(256,), # TODO(apaszke): 64, 192, 256
n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2
k_steps=(1, 2),
swizzle=(32, 64, 128,),
)
def test_mma_collective(
self,
m,
n,
k_steps,
swizzle,
lhs_transpose,
rhs_transpose,
in_jax_dtype,
out_jax_dtype,
):
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
raise self.skipTest("Only f16 input is supported for f16 output.")
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
m_block_tile = m // 2
n_block_tile = n // 2
swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
k = swizzle_elems * k_steps
index = ir.IndexType.get()
tiling = (8, swizzle_elems)
def kernel(ctx, lhs, rhs, out, scratch):
lhs_smem, rhs_smem, barriers, acc = scratch
block_id = gpu.cluster_block_id(gpu.Dimension.x)
ctx.async_copy(
src_ref=lhs,
dst_ref=lhs_smem,
swizzle=swizzle,
gmem_transform=mgpu.TileTransform(tiling),
barrier=barriers[0],
collective=gpu.Dimension.x,
partitioned=1 if lhs_transpose else 0, # Split non-contracting dim.
)
ctx.async_copy(
src_ref=rhs,
dst_ref=rhs_smem,
swizzle=swizzle,
gmem_transform=mgpu.TileTransform(tiling),
barrier=barriers[1],
collective=gpu.Dimension.x,
partitioned=0 if rhs_transpose else 1, # Split non-contracting dim.
)
is_leader_thread = single_thread_predicate()
is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index))
with when(arith.andi(is_first_block, is_leader_thread)):
barriers[0].wait()
barriers[1].wait()
if lhs_transpose:
lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2))
if rhs_transpose:
rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2))
tcgen05.mma(
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, collective=True
)
tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx)
barriers[2].wait(for_tensor_core=True)
m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile)
acc[:].store_untiled(memref_slice(out, m_slice))
in_finfo = jnp.finfo(in_jax_dtype)
exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant
def quantize(x):
# Quantize the input to avoid rounding when feeding the TensorCore
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
x_shape = (k, m) if lhs_transpose else (m, k)
x_block_shape = (k, m_block_tile) if lhs_transpose else (m_block_tile, k)
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype)
y_shape = (n, k) if rhs_transpose else (k, n)
y_block_shape = (n_block_tile, k) if rhs_transpose else (k, n_block_tile)
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
scratch_shape = [
jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype),
jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype),
mgpu.TMABarrier(3),
mgpu.TMEM((128, n), out_jax_dtype, collective=True),
]
z = mgpu.as_gpu_kernel(
kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1)
)(x, y)
x32, y32 = x.astype(np.float32), y.astype(np.float32)
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6
np.testing.assert_allclose(z, ref, atol=atol)
class BarrierTest(TestCase):
def test_wg_communication(self):
i32 = ir.IntegerType.get_signless(32)
def kernel(ctx, dst, scratch):
tmp, barriers = scratch
del ctx # Unused.
wg_idx = arith.divui(mgpu.warp_idx(), c(4, i32))
is_first_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(0, i32))
is_second_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(1, i32))
arr = mgpu.FragmentedArray.splat(
arith.addi(wg_idx, c(1, i32)),
(128,),
mgpu.WGStridedFragLayout((128,), 1),
is_signed=False,
)
with ir.InsertionPoint(scf.IfOp(is_first_wg).then_block):
arr.store_untiled(tmp)
barriers[0].arrive() # Signal that tmp is ready.
barriers[1].wait() # Wait for the other warpgroup to produce tmp.
final_arr = arr + mgpu.FragmentedArray.load_strided(
tmp, is_signed=False
)
final_arr.store_untiled(memref_slice(dst, 0))
scf.yield_([])
with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block):
barriers[0].wait()
final_arr = arr + mgpu.FragmentedArray.load_strided(
tmp, is_signed=False
)
barriers[2].arrive()
barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp.
arr.store_untiled(tmp)
barriers[1].arrive() # Signal that tmp is ready.
final_arr.store_untiled(memref_slice(dst, 1))
scf.yield_([])
out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32)
y = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(2 * 128, 1, 1),
(),
out_shape,
(
jax.ShapeDtypeStruct((128,), jnp.int32),
mgpu.Barrier(arrival_count=128, num_barriers=3),
),
)()
np.testing.assert_array_equal(y, np.full_like(y, 3, dtype=np.int32))
@parameterized.named_parameters(
(
f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}{'_group' if group_dims else ''}",
collective_dims,
noncollective_dims,
collective_size,
group_dims,
)
for collective_dims in itertools.chain.from_iterable(
itertools.combinations(Dimension, n) for n in range(1, 4)
)
for noncollective_dims in itertools.chain.from_iterable(
itertools.combinations(Dimension, n) for n in range(3)
)
for collective_size in (1, 2, 4)
for group_dims in (False,) + ((True,) if len(collective_dims) > 1 else ())
if all(d not in noncollective_dims for d in collective_dims)
)
def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size, group_dims):
i32 = ir.IntegerType.get_signless(32)
index = ir.IndexType.get()
cluster = [1, 1, 1]
for d in collective_dims:
cluster[d] = collective_size
for d in noncollective_dims:
cluster[d] = 2
if math.prod(cluster) > 16:
self.skipTest("Cluster too big")
is_trivial = math.prod(cluster[d] for d in collective_dims) == 1
def kernel(ctx, dst, mask, collective_barrier):
cluster_idx = ctx.cluster_idx()
if not is_trivial:
memref.store(collective_barrier.cluster_mask, mask, [cluster_idx])
else:
assert collective_barrier.cluster_mask is None
collective_barrier.arrive()
collective_barrier.wait()
tid = thread_idx()
linear_idx = arith.index_cast(index, tid)
stride = c(128, index)
for d in gpu.Dimension:
linear_idx = arith.addi(linear_idx, arith.muli(gpu.block_id(d), stride))
stride = arith.muli(stride, gpu.grid_dim(d))
memref.store(arith.index_cast(i32, linear_idx), dst, [linear_idx])
out_shape = jax.ShapeDtypeStruct((math.prod(cluster) * 128,), jnp.int32)
mask_shape = jax.ShapeDtypeStruct((math.prod(cluster),), jnp.int32)
barrier_dims = collective_dims
if group_dims:
barrier_dims = (collective_dims[:2], *collective_dims[2:])
scratch = mgpu.ClusterBarrier(barrier_dims)
y, mask = mgpu.as_gpu_kernel(
kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster,
)()
np.testing.assert_array_equal(
y, np.arange(math.prod(cluster) * 128, dtype=np.int32)
)
if not is_trivial:
# Verify that the mask is correct. Blocks are column-major, hence the transpose.
block_bits = 1 << np.arange(math.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T
expected_mask = 0
for bd in barrier_dims:
if isinstance(bd, gpu.Dimension):
bd = (bd,)
least_significant_slice = tuple(
slice(None) if d in bd else 0 for d in gpu.Dimension
)
mask_bits = block_bits[least_significant_slice]
expected_mask |= np.bitwise_or.reduce(mask_bits, axis=None)
self.assertEqual(min(mask), expected_mask)
class TMATest(TestCase):
@parameterized.product(
swizzle=(None, 32, 64, 128),
shape=((64, None), (5, None), (2, 3, 5, None)),
dtype=(jnp.float32, jnp.float16, jnp.int4),
)
def test_tma_load_basic(self, swizzle, shape, dtype):
bw = bitwidth(dtype_to_ir_type(dtype))
minor_size = 64 if swizzle is None else 8 * swizzle // bw
shape = (*shape[:-1], minor_size)
i1 = ir.IntegerType.get_signless(1)
def kernel(ctx, src, dst, smem):
tmp, barrier = smem
ctx.async_copy(src_ref=src, dst_ref=tmp, swizzle=swizzle, barrier=barrier)
barrier.wait_parity(c(0, i1))
copy(tmp, dst, swizzle=swizzle)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
smem = (x, mgpu.TMABarrier())
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x)
np.testing.assert_array_equal(y, x)
@parameterized.named_parameters(
(
f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
collective_dims,
noncollective_dims,
collective_size,
)
for collective_dims in itertools.chain.from_iterable(
itertools.combinations(Dimension, n) for n in range(1, 4)
)
for noncollective_dims in itertools.chain.from_iterable(
itertools.combinations(Dimension, n) for n in range(3)
)
for collective_size in (1, 2, 4)
if all(d not in noncollective_dims for d in collective_dims)
)
def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size):
index = ir.IndexType.get()
swizzle = 128
dtype = jnp.float16
cluster = [1, 1, 1]
for d in collective_dims:
cluster[d] = collective_dim_size
for d in noncollective_dims:
cluster[d] = 2
if math.prod(cluster) > 16:
self.skipTest("Cluster too big")
collective_size = math.prod(cluster[d] for d in collective_dims)
noncollective_size = math.prod(cluster) // collective_size
# We use the dimension of size 2 to exercise splitting the collective over
# multiple dimensions when the cluster is large.
shape = (noncollective_size, 2, 16 * collective_size, 64)
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
shape = (*shape[:-1], minor_size)
# Note that this kernel does not use the non-collective dimensions in any
# interesting way and so they don't really have to be part of the cluster.
# We use them to test that the multicast mask is generated correctly.
def kernel(ctx, src, dst, scratch):
tmp, barrier = scratch
stride = 1
noncollective_idx = c(0, index)
for d in noncollective_dims:
noncollective_idx = arith.addi(
noncollective_idx,
arith.muli(gpu.cluster_block_id(d), c(stride, index))
)
stride *= cluster[d]
ctx.async_copy(
src_ref=src,
dst_ref=tmp,
gmem_slice=(noncollective_idx,),
swizzle=swizzle,
barrier=barrier,
collective=collective_dims,
)
barrier.wait()
# This is _not_ the real cluster block idx, because it does not consider
# the column-major ordering of the grid dimensions.
idx = c(0, index)
stride = 1
for d in collective_dims:
idx = arith.addi(
idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))
)
stride *= cluster[d]
idx_minor = arith.divui(idx, c(2, index))
idx_major = arith.remui(idx, c(2, index))
slc_minor = ds(
arith.muli(idx_minor, c(16 * 2, index)), 16 * 2
)
copy(
memref_slice(tmp, (idx_major, slc_minor)),
memref_slice(dst, (noncollective_idx, idx_major, slc_minor)),
swizzle=swizzle,
)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
smem_shape = (jax.ShapeDtypeStruct(shape[1:], dtype), mgpu.TMABarrier())
y = mgpu.as_gpu_kernel(
kernel, cluster, (128, 1, 1), x, x, smem_shape, cluster=cluster
)(x)
np.testing.assert_array_equal(y, x)
@parameterized.product(
swizzle=(None, 128),
shape=((128, 128), (5, 32, 128)),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_load_tiled(self, swizzle, shape, dtype):
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
# swizzle 32 and 64.
i1 = ir.IntegerType.get_signless(1)
index = ir.IndexType.get()
tiling = (32, (swizzle or 128) // jnp.dtype(dtype).itemsize)
tiled_shape = tile_shape(shape, tiling)[:len(shape)]
def kernel(ctx, src, dst, scratch):
tmp, barrier = scratch
ctx.async_copy(
src_ref=src,
dst_ref=tmp,
swizzle=swizzle,
barrier=barrier,
gmem_transform=mgpu.TileTransform(tiling),
)
barrier.wait_parity(c(0, i1))
for idxs in np.ndindex(tiled_shape):
untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):]
s = (
*untiled_idxs,
*(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)),
)
copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
smem = (
jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype),
mgpu.TMABarrier(),
)
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)
y = f(x)
np.testing.assert_array_equal(y, x)
@parameterized.product(swizzle=(None, 128))
def test_tma_load_tiled_rounding(self, swizzle):
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
# swizzle 32 and 64.
shape = (5, 32, 144)
dtype = jnp.float16
i1 = ir.IntegerType.get_signless(1)
index = ir.IndexType.get()
tiling = (32, (swizzle or 128) // jnp.dtype(dtype).itemsize)
rounded_shape = (*shape[:-1], shape[-1] // tiling[-1] * tiling[-1])
tiled_shape = tile_shape(rounded_shape, tiling)[:len(shape)]
def kernel(ctx, src, dst, scratch):
tmp, barrier = scratch
ctx.async_copy(
src_ref=src,
dst_ref=tmp,
swizzle=swizzle,
barrier=barrier,
gmem_transform=mgpu.TileTransform(
tiling, rounding=mgpu.Rounding.DOWN
),
)
barrier.wait_parity(c(0, i1))
for idxs in np.ndindex(tiled_shape):
untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):]
s = (
*untiled_idxs,
*(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)),
)
copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
tmp_shape = jax.ShapeDtypeStruct(tile_shape(rounded_shape, tiling), dtype)
smem = (tmp_shape, mgpu.TMABarrier())
out_shape = jax.ShapeDtypeStruct(rounded_shape, dtype)
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem)
y = f(x)
np.testing.assert_array_equal(y, x[..., :rounded_shape[-1]])
def test_tma_load_indexed_tiled(self):
shape = (128, 2, 128)
tiling = mgpu.TileTransform((32, 32))
def kernel(ctx, src, dst, scratch):
tmp, barrier = scratch
ctx.async_copy(
src_ref=src,
dst_ref=tmp,
barrier=barrier,
gmem_transform=tiling,
gmem_slice=(slice(None), 1, slice(None)),
)
barrier.wait()
ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_transform=tiling)
ctx.await_async_copy(0)
x = np.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
smem = (
jax.ShapeDtypeStruct((4, 4, 32, 32), jnp.float32),
mgpu.TMABarrier(),
)
out_shape = jax.ShapeDtypeStruct((128, 128), jnp.float32)
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem)
np.testing.assert_array_equal(f(x), x[:, 1, :])
@parameterized.product(
swizzle=(None, 128),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_squeeze_indexing(self, swizzle, dtype):
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
# swizzle 32 and 64.
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
shape = (4, 5, minor_size)
def kernel(ctx, src, dst, smem):
tmp, barrier = smem
for i in range(4):
ctx.async_copy(
src_ref=src,
dst_ref=memref_slice(tmp, i),
gmem_slice=i,
swizzle=swizzle,
barrier=barrier,
)
barrier.wait()
copy(tmp, dst, swizzle=swizzle)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
smem = (x, mgpu.TMABarrier())
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x)
np.testing.assert_array_equal(y, x)
def test_parity_tracking(self):
shape = (16, 64)
index = ir.IndexType.get()
def kernel(ctx, src, dst, smem):
tmp, barrier = smem
for i in range(shape[0]):
s = ds(c(i, index), 1)
ctx.async_copy(
src_ref=src, dst_ref=tmp, gmem_slice=s, barrier=barrier,
)
barrier.wait()
copy(tmp, memref_slice(dst, s))
x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, x, (x[0:1], mgpu.TMABarrier())
)(x)
np.testing.assert_array_equal(y, x)
@parameterized.product(
swizzle=(None, 32, 64, 128),
shape=((64, None), (5, None), (2, 3, 5, None)),
dtype=(jnp.float16, jnp.float32),
)
def test_tma_store(self, swizzle, shape, dtype):
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
shape = (*shape[:-1], minor_size)
def kernel(ctx, src, dst, tmp):
copy(src, tmp, swizzle=swizzle)
ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle)
ctx.await_async_copy(0)
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
np.testing.assert_array_equal(y, x)
@parameterized.parameters(0, 1)
def test_tma_small_tile_load(self, small_dim):
if small_dim == 0:
shape = (4, 128)
elif small_dim == 1:
shape = (128, 8)
else:
raise ValueError("small_dim must be 0 or 1")
tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64)
padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2]))
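# tiled_shape above rounds each dimension up to a multiple of 64 and splits it
# into (64, 64) tiles; padded_shape is the matching padded 2D shape.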
def kernel(ctx, src, dst, smem):
tmp, barrier = smem
ctx.async_copy(
src_ref=src,
dst_ref=tmp,
swizzle=128,
gmem_transform=mgpu.TileTransform((64, 64)),
gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])),
barrier=barrier,
)
barrier.wait()
copy(tmp, dst, swizzle=128)
x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape)
tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16)
y_tiled = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, tiled, (tiled, mgpu.TMABarrier()),
)(x)
y = y_tiled.swapaxes(1, 2).reshape(padded_shape)
# y should contain x and zero everywhere else.
np.testing.assert_array_equal(y[:shape[0], :shape[1]], x)
y_mut = np.asarray(y).copy()
y_mut[:shape[0], :shape[1]] = 0
np.testing.assert_array_equal(y_mut, np.zeros_like(y_mut))
@parameterized.product(
small_dim=(0, 1),
tiled_layout=(False, True),
)
def test_tma_small_tile_store(self, small_dim, tiled_layout):
if small_dim == 0:
shape = (4, 128)
elif small_dim == 1:
shape = (128, 8)
else:
raise ValueError("small_dim must be 0 or 1")
tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64)
m, n = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2]))
def kernel(ctx, dst, tmp):
vals = iota_tensor(m, n, jnp.float16)
vals.store_tiled(tmp, swizzle=128)
ctx.async_copy(
src_ref=tmp,
dst_ref=dst,
swizzle=128,
gmem_transform=mgpu.TileTransform((64, 64)),
gmem_slice=(ds(0, m), ds(0, n)),
)
ctx.await_async_copy(0)
tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16)
out = jax.ShapeDtypeStruct(shape, jnp.float16)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out, tiled,
)()
iota = np.arange(m * n, dtype=jnp.float16).reshape([m, n])
np.testing.assert_array_equal(y, iota[:shape[0], :shape[1]])
def test_tma_invalid(self):
def kernel(ctx, src, dst, tmp):
copy(src, tmp)
ctx.async_copy(src_ref=tmp, dst_ref=dst)
ctx.await_async_copy(0)
def run_kernel(shape):
x = np.arange(np.prod(shape)).reshape(shape)
_ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"):
run_kernel([1] * 6)
with self.assertRaisesRegex(
ValueError, "last dimension to be divisible by 16"
):
run_kernel([23])
class FragmentedArrayTest(TestCase):
@parameterized.product(
op=(
operator.add,
operator.mul,
operator.sub,
(lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum),
(lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum),
),
dtype=[jnp.float32, jnp.int32, jnp.uint32],
m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256),
)
@jtu.ignore_warning(
message="(invalid value|divide by zero)", category=RuntimeWarning
)
def test_binary(self, op, dtype, m=64, n=32):
if isinstance(op, tuple):
op, np_op = op
else:
np_op = op
for scalar_rhs in [None, 2]:
def kernel(ctx, dst, _):
mlir_dtype = utils.dtype_to_ir_type(dtype)
iota = iota_tensor(m, n, dtype)
rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype)
op(iota, rhs).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
ref_x = np.arange(m * n, dtype=dtype).reshape(m, n)
ref_rhs = scalar_rhs or ref_x
np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs))
def test_minimum_np_compatibility(self):
one = np.ones((128, 128)).astype(np.float32)
negz = one * -0.
posz = one * 0.
nan = one * np.nan
expectation = (np.minimum(negz, posz) == negz) & (np.minimum(nan, one) != one)
assert np.all(expectation), expectation
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
splat = lambda i: mgpu.FragmentedArray.splat(c(i, f32), (128, 128))
negz = splat(-0.)
posz = splat(0.)
nan = splat(np.nan)
one = splat(1.)
res = (negz.min(posz) == negz) & (one.min(nan) != one) & (nan.min(one) != one)
i8 = ir.IntegerType.get_signless(8)
res.astype(i8, is_signed=False).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((128, 128), np.int8)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
np.testing.assert_array_equal(result == 1, expectation)
@parameterized.product(
op=[operator.truediv, operator.floordiv, operator.mod],
dtype=[jnp.float32, jnp.int32, jnp.uint32],
)
def test_division(self, op, dtype, m=64, n=32):
if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv:
self.skipTest("Unsupported for integer types")
if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod:
self.skipTest("Unsupported for floating types")
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, dtype)
op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
iota = np.arange(m * n, dtype=dtype).reshape(m, n)
np.testing.assert_allclose(
result, op(dtype(4.2).item() * iota, iota + 1), atol=2e-7
)
@parameterized.product(
op=[
operator.lt,
operator.le,
operator.gt,
operator.ge,
operator.eq,
operator.ne,
],
dtype=[jnp.float32, jnp.int32, jnp.uint32],
rhs_is_literal=[False, True]
)
def test_comparison(self, op, dtype, rhs_is_literal, m=64, n=32):
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, dtype)
rhs = 0 if rhs_is_literal else iota + 1
op(iota, rhs).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
iota = np.arange(m * n, dtype=dtype).reshape(m, n)
rhs = 0 if rhs_is_literal else iota + 1
np.testing.assert_array_equal(result, op(iota, rhs))
@parameterized.product(tiled_layout=(False, True))
def test_foreach(self, tiled_layout):
dtype = jnp.int32
swizzle = 128
tile = 64, swizzle // jnp.dtype(dtype).itemsize
shape = 128, 192
tiled_shape = mgpu.tile_shape(shape, tile)
mlir_dtype = utils.dtype_to_ir_type(dtype)
cst = 9999
def causal(val, idx):
row, col = idx
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
return arith.select(mask, val, c(cst, mlir_dtype))
tiling = mgpu.TileTransform(tile)
def kernel(ctx, dst, smem):
x = iota_tensor(shape[0], shape[1], dtype, tiled_layout=tiled_layout)
x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem, dst_ref=dst)
ctx.await_async_copy(0)
iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
result = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
(),
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
)()
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
np.testing.assert_array_equal(result, expected)
@parameterized.product(
op=[operator.and_, operator.or_, operator.xor],
dtype=[jnp.uint32],
)
def test_bitwise(self, op, dtype, m=64, n=8):
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, dtype)
op(iota, iota + 1).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
iota = np.arange(m * n, dtype=dtype).reshape(m, n)
np.testing.assert_array_equal(result, op(iota, iota + 1))
@parameterized.product(
ops=(
(lambda x: -x, jax.lax.neg),
(lambda x: x + 42, lambda x: x + 42),
(lambda x: x.tanh(), jax.lax.tanh),
),
dtype=[jnp.float32, jnp.int32, jnp.uint32],
)
def test_unary(self, ops, dtype, m=64, n=32):
op, np_op = ops
if np_op is jax.lax.tanh and jnp.issubdtype(dtype, jnp.integer):
self.skipTest("Tanh not supported for integer types")
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, dtype)
op(iota).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=dtype).reshape(m, n)
np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7)
def test_select(self, m=64, n=32):
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, jnp.int32)
(iota < 16).select(iota * 2, iota * 3).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=jnp.int32).reshape(m, n)
np.testing.assert_array_equal(result, np.where(x < 16, x * 2, x * 3))
@parameterized.product(
ops=[
(lambda x: mgpu.FragmentedArray.exp(x), np.exp),
(lambda x: mgpu.FragmentedArray.sin(x), np.sin),
(lambda x: mgpu.FragmentedArray.cos(x), np.cos),
(lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt),
],
approx=[False, True],
)
@jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning)
def test_math(self, ops, approx, m=64, n=32):
op, np_op = ops
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, jnp.float32)
op(iota).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
atol = 5e-3 if approx else 2e-7
rtol = 4e-6 if approx else 2e-7
np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol)
@parameterized.product(
dtype=[jnp.float32, jnp.int32],
m=[128],
n=[32, 64],
)
def test_strided_reduce_sum(self, dtype, m, n):
def kernel(ctx, src, dst, scratch):
src = mgpu.FragmentedArray.load_strided(
src, is_signed=utils.is_signed(dtype)
)
acc = src.reduce_sum(scratch).broadcast((m,))
acc.store_untiled(dst)
in_shape = jax.ShapeDtypeStruct((m, n), dtype)
out_shape = jax.ShapeDtypeStruct((m,), dtype)
kernel_fn = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
in_shape,
out_shape,
smem_scratch_shape=jax.ShapeDtypeStruct((4,), dtype),
)
x = np.arange(m * n, dtype=dtype).reshape(m, n)
np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum()))
@parameterized.product(
dtype=[jnp.float32, jnp.int32],
m=[128],
n=[32, 64],
)
def test_splat_reduce_sum(self, dtype, m, n):
def kernel(ctx, dst, _):
src = mgpu.FragmentedArray.splat(
utils.c(1, utils.dtype_to_ir_type(dtype)),
(m, n),
is_signed=utils.is_signed(dtype),
)
acc = src.reduce_sum().broadcast((m,))
acc.store_untiled(dst)
kernel_fn = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
in_shape=(),
out_shape=jax.ShapeDtypeStruct((m,), dtype),
smem_scratch_shape=(),
)
np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), m * n * 1.0))
@parameterized.product(
op=(arith.addf, arith.maximumf),
m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256),
)
def test_reduce(self, op, m=64, n=32):
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, jnp.float32)
iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
if op == arith.addf:
expected = np.broadcast_to(x.sum(axis=1, keepdims=True), x.shape)
elif op == arith.maximumf:
expected = np.broadcast_to(x.max(axis=1, keepdims=True), x.shape)
else:
raise NotImplementedError(f"Unsupported op: {op}")
np.testing.assert_array_equal(result, expected)
def test_splat_layout(self):
m, n = 64, 8
def kernel(ctx, dst, _):
iota = iota_tensor(m, n, jnp.float32)
cte = c(1, iota.mlir_dtype)
cte_arr = mgpu.FragmentedArray.splat(cte, ())
cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n))
(iota + cte_arr).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
expected = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1
np.testing.assert_array_equal(result, expected)
def test_splat(self):
def kernel(ctx, dst, _):
f32 = ir.F32Type.get()
v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14))
t = mgpu.FragmentedArray.splat(
v, (128,), mgpu.WGMMA_ROW_LAYOUT
)
t.broadcast_minor(32).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32))
def test_splat_binary_ops(self):
def kernel(ctx, src, dst, _):
f32 = ir.F32Type.get()
pi_arr = mgpu.FragmentedArray.load_strided(src)
assert isinstance(pi_arr.layout, mgpu.WGStridedFragLayout)
pi_scalar = arith.constant(f32, ir.FloatAttr.get(f32, 3.14))
pi_splat = mgpu.FragmentedArray.splat(pi_scalar, ())
assert isinstance(pi_splat.layout, mgpu.WGSplatFragLayout)
pi_arr_sq = pi_arr * pi_splat.broadcast(pi_arr.shape)
assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout)
pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq
assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout)
(pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
inp = jnp.ones_like(out_shape) * 3.14
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, ()
)(inp)
np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
def test_strided_load_store(self, in_shape):
def kernel(ctx, *args):
gmem_input, gmem_output, (smem_input, smem_output) = args
copy(gmem_input, smem_input)
t = mgpu.FragmentedArray.load_strided(smem_input)
t.store_untiled(smem_output)
copy(smem_output, gmem_output)
inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
)(inp)
np.testing.assert_array_equal(inp, result)
def test_warp_tree_reduce(self):
def kernel(ctx, out, *_):
del ctx
i32 = ir.IntegerType.get_signless(32)
tid = gpu.thread_id(gpu.Dimension.x)
value = arith.index_cast(i32, tid)
grp = warp_tree_reduce(value, arith.addi, 4)
memref.store(grp, out, [tid])
x = np.arange(128, dtype=jnp.int32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), x, [],
)()
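# Each consecutive group of 4 thread ids is reduced to its sum
# (e.g. threads 0..3 all end up holding 0 + 1 + 2 + 3 = 6).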
for i in range(0, 128, 4):
x[i:i + 4] = jnp.sum(x[i:i + 4])
np.testing.assert_array_equal(result, x)
@parameterized.parameters(2, 4)
def test_fast_i8_convert(self, reg_length):
jax_dtype_to = jnp.dtype(jnp.bfloat16)
jax_dtype_from = jnp.dtype(jnp.int8)
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
def kernel(ctx, inp, out, smem):
del ctx, smem
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True)
assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length]
arr.astype(mlir_dtype_to).store_untiled(out)
x = jnp.arange(-128, 128, dtype=jax_dtype_from)
x = jnp.tile(x, reg_length // 2)
reference = x.astype(jax_dtype_to)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, reference, None,
)(x)
np.testing.assert_array_equal(result, reference)
@parameterized.parameters(
([64 * 4], "WGMMA_ROW_LAYOUT"),
([64 * 4, 8 * 2], "WGMMA_LAYOUT"),
([64 * 4, 8 * 2], "TILED_LAYOUT_WGMMA"),
)
def test_to_layout(self, shape, new_layout):
def kernel(ctx, _):
# No assertions, we are just checking there are no compile-time errors.
arr = mgpu.FragmentedArray.splat(c(42.0, ir.F32Type.get()), shape)
arr.to_layout(getattr(mgpu, new_layout))
_ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)()
@parameterized.parameters(
(jnp.float16, jnp.float16), # Noop
(jnp.int16, jnp.bfloat16),
(jnp.int16, jnp.float16),
(jnp.uint16, jnp.float16),
(jnp.float32, jnp.int32),
(jnp.float32, jnp.uint32),
(jnp.uint32, jnp.int32),
(jnp.int32, jnp.uint32),
)
def test_bitcast(self, in_dtype, out_dtype):
out_ir_type = utils.dtype_to_ir_type(out_dtype)
in_is_signed = utils.is_signed(in_dtype)
out_is_signed = utils.is_signed(out_dtype)
def kernel(ctx, inp, out, smem):
del ctx, smem
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=in_is_signed)
arr = arr.bitcast(out_ir_type, output_is_signed=out_is_signed)
arr.store_untiled(out)
x = jnp.arange(256, dtype=in_dtype)
reference = jax.lax.bitcast_convert_type(x, out_dtype)
result = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
x,
reference,
None,
)(x)
np.testing.assert_array_equal(result, reference)
@parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16)
def test_optimization_barrier(self, dtype):
def kernel(ctx, inp, out, smem):
del ctx, smem
arr = mgpu.FragmentedArray.load_strided(inp)
arr2 = arr * 2
arr, arr2 = mgpu.optimization_barrier(arr, arr2)
(arr + arr2).store_untiled(out)
x = jnp.arange(256, dtype=dtype)
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None)
np.testing.assert_array_equal(f(x), x * 3)
def test_convert_bool_to_u8(self):
m, n = 128, 128
def kernel(ctx, dst, _):
i8 = ir.IntegerType.get_signless(8)
iota = iota_tensor(m, n, jnp.uint8)
(iota > 10).astype(i8, is_signed=False).store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
iota = np.arange(m * n, dtype=jnp.uint8).reshape(m, n)
np.testing.assert_array_equal(result, (iota > 10).astype(jnp.uint8))
@parameterized.parameters(
(jnp.uint8, jnp.uint16, 255),
(jnp.uint8, jnp.int16, 255),
(jnp.int8, jnp.uint16, -127),
(jnp.int8, jnp.int16, -127),
)
def test_convert_int_uint(self, from_dtype, to_dtype, value):
m, n = 1, 128
def kernel(ctx, dst, _):
i8 = ir.IntegerType.get_signless(8)
from_mlir_dtype = utils.dtype_to_ir_type(from_dtype)
to_mlir_dtype = utils.dtype_to_ir_type(to_dtype)
from_arr = mgpu.FragmentedArray.splat(
c(value, from_mlir_dtype),
(m, n),
is_signed=utils.is_signed(from_dtype),
)
to_arr = from_arr.astype(to_mlir_dtype, is_signed=utils.is_signed(to_dtype))
to_arr.store_untiled(dst)
out_shape = jax.ShapeDtypeStruct((m, n), to_dtype)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
expected = jnp.full((m, n), value, dtype=from_dtype).astype(to_dtype)
np.testing.assert_array_equal(result, expected)
class ProfilerTest(TestCase):
def test_measure_events_explicit(self):
x = jnp.arange(1024 * 1024)
_, runtime_ms = profiler.measure(lambda x, y: x + y, mode="events")(x, x)
self.assertIsInstance(runtime_ms, float)
def test_profile(self):
def kernel(ctx, src, dst, _):
mgpu.FragmentedArray.load_strided(src).store_untiled(dst)
x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
spec = profiler.ProfilerSpec(1024)
# This is just a smoke test.
f = jax.jit(mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, x, (), prof_spec=spec
))
jax.block_until_ready(f(x))
def test_multigpu(self):
if len(jax.devices()) < 2:
self.skipTest("Need at least 2 devices")
def kernel(ctx, src, dst, _):
mgpu.FragmentedArray.load_strided(src).store_untiled(dst)
x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
f = jax.jit(mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, x, ()
))
# Make sure we can invoke the same program on different devices.
for xd in (jax.device_put(x, d) for d in jax.devices()[:2]):
jax.block_until_ready(f(xd))
class TorchTest(TestCase):
def setUp(self):
super().setUp()
try:
import torch
except ImportError:
raise unittest.SkipTest("Test requires PyTorch")
self.torch = torch
def test_basic(self):
def kernel(ctx, i_gmem, o_gmem, _):
x = mgpu.FragmentedArray.load_strided(i_gmem)
(x + x).store_untiled(o_gmem)
ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda')
f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())
y = f(x)
np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
del y # Make sure the destructor runs successfully.
class LayoutTest(TestCase):
@parameterized.named_parameters(
("f32", jnp.float32, 256),
("f16", jnp.float16, 256),
("f16_small", jnp.float16, 128),
)
def test_store_untiled_splat(self, jax_dtype, size):
mlir_dtype = utils.dtype_to_ir_type(jax_dtype)
def kernel(ctx, out, _):
del ctx
arr = mgpu.FragmentedArray.splat(
c(1.0, mlir_dtype), (size,), is_signed=utils.is_signed(jax_dtype)
)
arr.store_untiled(out)
expected = np.ones((size,), jax_dtype)
mosaic_ones = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
)()
np.testing.assert_array_equal(mosaic_ones, expected)
@parameterized.product(
shape=((128, 128), (64, 8), (64, 256)),
dtype=(jnp.int32, jnp.int16, jnp.int8),
)
def test_wgmma_tiled_layout(self, shape, dtype):
def kernel(ctx, dst, _):
iota = iota_tensor(*shape, dtype)
tiled = iota.to_layout(fa._tiled_wgmma_layout(shape))
# Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1),
# padded with trivial (size-1) dimensions in the tiled register representation.
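# For example, for shape == (128, 128) the registers array checked below has
# shape (2, 16, 1, 1, 2, 1, 1, 1, 1, 1).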
self.assertEqual(
tiled.registers.shape,
(shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1),
)
self.assertEqual(tiled.shape, shape)
self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype)
tiled.store_untiled(dst)
ty = jax.ShapeDtypeStruct(shape, dtype)
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ())
expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
np.testing.assert_array_equal(f(), expected)
@parameterized.product(
load_tiled=[False, True],
store_tiled=[False, True],
dtype=[jnp.int8, jnp.int16, jnp.int32],
swizzle=[16, 32, 64, 128],
num_col_tiles=[1, 2, 3],
row_tiling=[8, 64],
)
def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles, row_tiling):
if (not load_tiled or not store_tiled) and row_tiling != 64:
self.skipTest("Old code path does not support this")
mlir_dtype = utils.dtype_to_ir_type(dtype)
bw = bytewidth(mlir_dtype)
col_tiling = swizzle // bw
if col_tiling % 8:
self.skipTest("WGMMA layout requires col_tiling % 8 == 0")
m, n = 128, col_tiling * num_col_tiles
tiling = (row_tiling, col_tiling)
tiled_layout = fa._tiled_wgmma_layout((m, n))
load_layout = tiled_layout if load_tiled else mgpu.TILED_LAYOUT_WGMMA
store_layout = tiled_layout if store_tiled else mgpu.TILED_LAYOUT_WGMMA
if (not load_tiled or not store_tiled) and bw == 4 and swizzle == 32:
self.skipTest("Old code path does not support this")
def kernel(ctx, in_, out, smems):
smem_in, smem_out, barrier = smems
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
barrier.wait()
t = mgpu.FragmentedArray.load_tiled(
smem_in, swizzle=swizzle, is_signed=True, layout=load_layout
)
t.to_layout(store_layout).store_tiled(smem_out, swizzle=swizzle)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
ctx.await_async_copy(0)
expected = (
np.arange(m * n, dtype=dtype)
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
.transpose(0, 2, 1, 3)
)
prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
try:
with jtu.capture_stdout() as get_sass:
iota = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
[expected, expected, mgpu.TMABarrier()],
)(expected)
finally:
if prev_dump is not None:
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
else:
del os.environ["MOSAIC_GPU_DUMP_SASS"]
np.testing.assert_array_equal(iota, expected)
# Verify that we don't use too many registers for the transfers.
# We verify LDS and STS separately, because they might use two different
# methods of computing offsets and we don't rely on CSE between them.
expected_regs = swizzle // bytewidth(mlir_dtype) // 8
# When the bytewidth is smaller than 2 the swizzle pattern changes every 2
# column tiles, so we only need half the registers.
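# For instance (illustrative): swizzle=128 with int8 (bytewidth 1) gives
# 128 // 1 // 8 = 16 registers, halved to 8 when both sides use the tiled layout.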
if load_tiled and store_tiled: # The old code doesn't optimize properly.
if bytewidth(mlir_dtype) < 2:
expected_regs //= 2
for instr in ("STS", "LDS"):
with self.subTest(instr + " count"):
addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
def get_reg(addr):
if (pos := addr.find("+")) != -1:
return addr[:pos]
return addr
used_regs = {get_reg(addr) for addr in addrs}
self.assertLessEqual(len(used_regs), expected_regs)
def test_copy_for_upcast(self):
dtype = jnp.int8
swizzle = 128
col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype))
m, n = 128, col_tiling * 2
tiling = (64, col_tiling)
tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n))
def kernel(ctx, in_, out, smems):
smem_in, smem_out, barrier = smems
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
barrier.wait()
t = mgpu.FragmentedArray.load_tiled(
smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout
)
t.store_tiled(smem_out, swizzle=swizzle)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
ctx.await_async_copy(0)
x = jax.random.randint(
jax.random.key(42), tile_shape((m, n), tiling), -128, 127, dtype=dtype
)
f = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, x, [x, x, mgpu.TMABarrier()],
)
np.testing.assert_array_equal(f(x), x)
@parameterized.product(
dtype=[jnp.int16], # TODO(apaszke): More dtypes
# TODO(apaszke): swizzle=64 <- not implemented in transfer_tiled right now
swizzle=[16, 32, 128],
)
def test_transpose_tiled(self, dtype, swizzle):
mlir_dtype = utils.dtype_to_ir_type(dtype)
bw = bytewidth(mlir_dtype)
col_tiling = swizzle // bw
m, n = 128, 256
tiling = (8, col_tiling)
transpose_layout = fa.WGMMA_TRANSPOSED_LAYOUT
def kernel(ctx, in_, out, smems):
smem_in, smem_out, barrier = smems
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
barrier.wait()
t = mgpu.FragmentedArray.load_tiled(
smem_in, swizzle=swizzle, is_signed=True, layout=fa.TILED_LAYOUT_WGMMA
)
smem_out_t = memref_transpose(smem_out, (1, 0, 3, 2))
t.to_layout(transpose_layout).store_tiled(smem_out_t, swizzle=swizzle)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
ctx.await_async_copy(0)
x = (
np.arange(m * n, dtype=dtype)
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
.transpose(0, 2, 1, 3)
)
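# The reference output is the transposed matrix, re-tiled with the same
# (8, col_tiling) tiling.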
y_ref = (
np.arange(m * n, dtype=dtype)
.reshape(m, n)
.T.reshape(n // tiling[0], tiling[0], m // tiling[1], tiling[1])
.transpose(0, 2, 1, 3)
)
y = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), x, y_ref, [x, y_ref, mgpu.TMABarrier()],
)(x)
np.testing.assert_array_equal(y, y_ref)
@dataclasses.dataclass(frozen=True)
class Tile:
"""Defines a Tile transform in a TestCaseInput.
Note that we cannot simply alias mgpu_dialect.TileTransformAttr.get, because
we do not have an MLIR context at the point we define the TestCaseInput.
"""
tiling: tuple[int, ...]
def attr(self):
return mgpu_dialect.TileTransformAttr.get(self.tiling)
@dataclasses.dataclass(frozen=True)
class Transpose:
"""Defines a Transpose transform in a TestCaseInput.
Note that we cannot simply alias mgpu_dialect.TransposeTransformAttr.get,
because we do not have an MLIR context at the point we define the
TestCaseInput.
"""
permutation: tuple[int, ...]
def attr(self):
return mgpu_dialect.TransposeTransformAttr.get(self.permutation)
@dataclasses.dataclass(frozen=True)
class Swizzle:
"""Defines a Swizzle transform in a TestCaseInput.
Note that we cannot simply alias mgpu_dialect.SwizzleTransformAttr.get,
because we do not have an MLIR context at the point we define the
TestCaseInput.
"""
swizzle: mgpu_dialect.SwizzlingMode
def attr(self):
return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)
def memref_with_transforms(
mem_ref: ir.Value,
transforms: Sequence[Tile | Transpose | Swizzle],
) -> ir.Value:
"""Casts the memref to one that has a layout with the given transforms."""
mem_ref_type = ir.MemRefType(mem_ref.type)
transform_attr = [t.attr() for t in transforms]
if not transform_attr:
return mem_ref
layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr)
memref_new_type = ir.MemRefType.get(
mem_ref_type.shape,
mem_ref_type.element_type,
layout,
mem_ref_type.memory_space,
)
return memref.cast(memref_new_type, mem_ref)
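# Illustrative usage (with a hypothetical SMEM memref `smem_ref` and an active
# MLIR context):
#   memref_with_transforms(
#       smem_ref,
#       [Tile((64, 64)), Swizzle(mgpu_dialect.SwizzlingMode.k128ByteSwizzle)],
#   )
# returns a cast of the memref whose layout attribute carries the tile and
# swizzle transforms.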
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
"""Device tests with lowering from the MLIR dialect and layout inference."""
def setUp(self):
if mgpu_dialect is None:
raise self.skipTest("Test requires Mosaic GPU dialect")
super().setUp()
def test_pointwise_kernel(self):
def add(ctx, a, b, result, smem):
del ctx, smem
shape = ir.MemRefType(a.type).shape
elt_type = ir.MemRefType(a.type).element_type
zero_index = arith.constant(ir.IndexType.get(), 0)
zero_vector_indices = [zero_index] * len(shape)
# GMEM -> registers
ab_type = ir.VectorType.get(shape, elt_type)
a = vector.load(ab_type, a, zero_vector_indices)
b = vector.load(ab_type, b, zero_vector_indices)
# Computation
add = arith.addf(a, b)
# Registers -> GMEM
vector.store(add, result, zero_vector_indices)
dtype = jnp.bfloat16
shape = (128, 128)
jax_shape = jax.ShapeDtypeStruct(shape, dtype)
kernel = mgpu.as_gpu_kernel(
add,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(jax_shape, jax_shape),
out_shape=jax_shape,
smem_scratch_shape=[],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, shape).astype(dtype)
y = self.prng.uniform(-1, 1, shape).astype(dtype)
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
@staticmethod
def kernel_with_tma_cases(dtype: jnp.dtype):
@dataclasses.dataclass()
class TestCaseInput:
shape: tuple[int, ...]
shape_sliced: tuple[int, ...] = ()
slice_indices: tuple[int, ...] = ()
slice_lengths: tuple[int, ...] = ()
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
def __post_init__(self):
if not self.shape_sliced:
self.shape_sliced = self.shape
if not self.slice_lengths:
self.slice_lengths = self.shape_sliced
if not self.slice_indices:
self.slice_indices = tuple([0] * len(self.slice_lengths))
result = []
for swizzle in mgpu_dialect.SwizzlingMode:
n = swizzle * 8 // jnp.finfo(dtype).bits
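# e.g. for bf16 (16-bit elements) and k128ByteSwizzle this gives
# n = 128 * 8 // 16 = 64.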
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
# We need at least one case with no transforms, as this is handled
# differently.
result.append(TestCaseInput(shape=[128, n]))
result.extend([
TestCaseInput(
shape=[128, n],
transforms=[Swizzle(swizzle)],
),
TestCaseInput(
shape=[256, n],
shape_sliced=[128, n],
transforms=[Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 128, n],
shape_sliced=[128, n],
slice_lengths=[-1, 128, n],
slice_indices=[1, 0, 0],
transforms=[Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[
Transpose([1, 0, 2, 3]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[256, n],
shape_sliced=[128, n],
transforms=[Tile([64, n]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2 * 64, 3 * n],
transforms=[
Tile([64, n]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
])
return result
@parameterized.parameters(kernel_with_tma_cases(jnp.bfloat16))
def test_kernel_with_tma(self, test_case):
def add(
ctx: launch_context.LaunchContext,
in_gmem_ref: ir.Value,
result_gmem_ref: ir.Value,
smem: list[ir.Value],
):
del ctx
smem_ref, tma_barrier = smem
smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
elt_type = ir.MemRefType(in_gmem_ref.type).element_type
memref_bytes = utils.bytewidth(elt_type) * math.prod(
test_case.shape_sliced
)
mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier, expect_tx=memref_bytes
)
i32 = ir.IntegerType.get_signless(32)
slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]
# GMEM -> SMEM
mgpu_dialect.async_load(
source=in_gmem_ref,
destination=smem_ref,
barrier=dialect_barrier,
indices=slice_indices,
slice_lengths=test_case.slice_lengths,
collective=ir.ArrayAttr.get([]),
)
parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity)
# SMEM -> GMEM
zero_index = arith.constant(i32, 0)
mgpu_dialect.async_store(
source=smem_ref,
destination=result_gmem_ref,
indices=[zero_index] * len(test_case.shape_sliced),
slice_lengths=test_case.shape_sliced,
)
nvvm.cp_async_bulk_wait_group(0)
utils.warpgroup_barrier()
dtype = jnp.bfloat16
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
jax_shape_sliced = jax.ShapeDtypeStruct(test_case.shape_sliced, dtype)
kernel = mgpu.as_gpu_kernel(
add,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(jax_shape,),
out_shape=jax_shape_sliced,
smem_scratch_shape=[
jax_shape_sliced,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
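# A slice length of -1 marks a dimension that is squeezed away; abs() turns it
# into a length-1 slice so the reshape in the check below drops it.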
input_slice = tuple(
slice(i * abs(l), (i + 1) * abs(l))
for i, l in zip(test_case.slice_indices, test_case.slice_lengths)
)
self.assertArraysEqual(
jax.jit(kernel)(x),
(x[input_slice]).reshape(test_case.shape_sliced),
)
@staticmethod
def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
@dataclasses.dataclass(frozen=True)
class TestCaseInput:
shape: tuple[int, ...]
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
result = []
for swizzle in mgpu_dialect.SwizzlingMode:
n = swizzle * 8 // jnp.finfo(dtype).bits
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
# We need at least one case with no transforms, as this is handled
# differently.
result.append(TestCaseInput(shape=[128, n]))
result.extend([
TestCaseInput(
shape=[128, n],
transforms=[Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[
Transpose([1, 0, 2, 3]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
TestCaseInput(
shape=[2, 3, 64, n],
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[128, n],
transforms=[Tile([64, n]), Swizzle(swizzle)],
),
TestCaseInput(
shape=[2 * 64, 3 * n],
transforms=[
Tile([64, n]),
Transpose([1, 0, 2, 3]),
Swizzle(swizzle),
],
),
])
return result
@parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
def test_pointwise_kernel_with_tma(self, test_case):
def add(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
b_gmem_ref: ir.Value,
result_gmem_ref: ir.Value,
smem: list[ir.Value],
):
del ctx
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
memref_type = ir.MemRefType(a_gmem_ref.type)
shape = memref_type.shape
elt_type = memref_type.element_type
memref_bytes = utils.bytewidth(elt_type) * math.prod(shape)
mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier, expect_tx=2 * memref_bytes
)
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
zero_slice_indices = [zero_i32] * memref_type.rank
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=memref_with_transforms(
a_smem_ref, test_case.transforms
),
barrier=dialect_barrier,
indices=zero_slice_indices,
slice_lengths=shape,
collective=ir.ArrayAttr.get([]),
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=memref_with_transforms(
b_smem_ref, test_case.transforms
),
barrier=dialect_barrier,
indices=zero_slice_indices,
slice_lengths=shape,
collective=ir.ArrayAttr.get([]),
)
parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity)
zero_index = arith.constant(ir.IndexType.get(), 0)
zero_vector_indices = [zero_index] * memref_type.rank
# SMEM -> registers
ab_type = ir.VectorType.get(shape, elt_type)
a = vector.load(ab_type, a_smem_ref, zero_vector_indices)
b = vector.load(ab_type, b_smem_ref, zero_vector_indices)
# Computation
add = arith.addf(arith.addf(a, b), b)
# Registers -> SMEM
vector.store(add, result_smem_ref, zero_vector_indices)
# SMEM -> GMEM
mgpu_dialect.async_store(
source=memref_with_transforms(
result_smem_ref, test_case.transforms
),
destination=result_gmem_ref,
indices=zero_slice_indices,
slice_lengths=shape,
)
nvvm.cp_async_bulk_wait_group(0)
utils.warpgroup_barrier()
dtype = jnp.bfloat16
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
kernel = mgpu.as_gpu_kernel(
add,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(jax_shape, jax_shape),
out_shape=jax_shape,
smem_scratch_shape=[
jax_shape,
jax_shape,
jax_shape,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
@staticmethod
def wgmma_kernel_with_tma_cases(abtype: jnp.dtype):
@dataclasses.dataclass(frozen=True)
class TestCaseInput:
shape_a: tuple[int, ...] = ()
shape_b: tuple[int, ...] = ()
shape_res: tuple[int, ...] = ()
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
transpose_a: bool = False
transpose_b: bool = False
load_a_in_registers: bool = False
result = []
for swizzle in [
# TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
]:
k = swizzle // np.dtype(abtype).itemsize
groups_m = 4
groups_n = 1
groups_k = 1
result.extend([
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_n * k, groups_k * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_b=True,
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
),
TestCaseInput(
shape_a=[groups_m * 64, groups_k * k],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
load_a_in_registers=True,
),
])
# The below only works for 128-byte swizzling. Regardless of transposing,
# TMA needs the size of the last dimension to be compatible with the
# swizzle.
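# For bf16 and k128ByteSwizzle, k = 128 // 2 = 64, so the transposed-A case
# below uses shape_a = [64, 256].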
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
result.append(
TestCaseInput(
shape_a=[groups_k * k, groups_m * 64],
shape_b=[groups_k * k, groups_n * k],
shape_res=[groups_m * 64, groups_n * k],
transpose_a=True,
)
)
return result
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
def test_wgmma_kernel_with_tma(self, test_case):
def matmul(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
b_gmem_ref: ir.Value,
result_gmem_ref: ir.Value,
smem: list[ir.Value],
):
del ctx
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a)
bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b)
mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier,
expect_tx=bytes_a + bytes_b,
)
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_a),
slice_lengths=test_case.shape_a,
collective=ir.ArrayAttr.get([]),
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
barrier=dialect_barrier,
indices=[zero_i32] * len(test_case.shape_b),
slice_lengths=test_case.shape_b,
collective=ir.ArrayAttr.get([]),
)
parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity)
# SMEM -> Registers
a_operand = a_smem_ref
zero_index = arith.constant(ir.IndexType.get(), 0)
if test_case.load_a_in_registers:
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
zero_vector_indices = [zero_index] * len(test_case.shape_a)
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
# Computation
shape_result = ir.MemRefType(result_gmem_ref.type).shape
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
zero_acc = arith.constant(
result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0)
)
accumulator = vector.splat(
ir.VectorType.get(shape_result, result_elt_type), zero_acc
)
result = mgpu_dialect.wgmma(
accumulator,
a_operand,
b_smem_ref,
transpose_a=test_case.transpose_a,
transpose_b=test_case.transpose_b,
)
nvvm.wgmma_commit_group_sync_aligned()
nvvm.wgmma_wait_group_sync_aligned(0)
# Registers -> SMEM
vector.store(result, result_smem_ref, [zero_index] * len(shape_result))
# SMEM -> GMEM
mgpu_dialect.async_store(
source=result_smem_ref,
destination=result_gmem_ref,
indices=[zero_i32, zero_i32],
slice_lengths=shape_result,
)
nvvm.cp_async_bulk_wait_group(0)
abtype = jnp.bfloat16
acctype = jnp.float32
a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype)
b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype)
result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype)
kernel = mgpu.as_gpu_kernel(
matmul,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(a_jax_shape, b_jax_shape),
out_shape=result_jax_shape,
smem_scratch_shape=[
a_jax_shape,
b_jax_shape,
result_jax_shape,
core.TMABarrier(1),
],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
transpose = lambda x, t: x.T if t else x
self.assertArraysAllClose(
jax.jit(kernel)(x, y),
np.matmul(
transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
),
atol=1e-5,
rtol=1e-5,
)
class UtilsTest(TestCase):
@parameterized.parameters(
(1,),
(-1,),
(slice(2), slice(3),),
(slice(1), slice(1, 3)),
(slice(-2, 0),),
(slice(-2, -1),),
*([(utils.DynamicSlice(0, 2),)] if HAS_MOSAIC_GPU else []),
)
def test_parse_indices(self, *indices):
# We are simply making sure this does not raise.
_, _, _ = utils.parse_indices(indices, (2, 3, 4))
@parameterized.parameters(
(42,),
(-42,),
(slice(42),),
(slice(0, 42),),
(slice(-42, 0),),
(slice(-4, -42),),
*([(utils.DynamicSlice(0, 4),)] if HAS_MOSAIC_GPU else []),
)
def test_parse_indices_oob(self, indices):
with self.assertRaisesRegex(IndexError, "out of bounds"):
utils.parse_indices(indices, (2, 3, 4))
class SerializationTest(absltest.TestCase):
def test_pass_is_registered(self):
ctx = mlir.make_ir_context()
ctx.allow_unregistered_dialects = True
with ir.Location.unknown(ctx):
module = ir.Module.create()
pipeline = passmanager.PassManager.parse(
"builtin.module(mosaic_gpu-serde{serialize=true})",
ctx,
)
pipeline.run(module.operation)
if __name__ == "__main__":
absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader())