This one is particularly annoying, because we have to break the MMA up into two collective N=256 MMAs. However, TensorCore only updates a contiguous chunk of columns in TMEM, so after executing two of those we end up with a TMEM layout that looks like this:

```
Contributing CTA |   0   |    1    |    0    |    1    |
N local          | 0:128 |  0:128  | 128:256 | 128:256 |
N                | 0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer map monotonically onto the logical columns of the N=512 result; the mapping contains a number of jumps. We could fix this on the load side, by ensuring that each CTA in the group does a strided load along the tiled dimension, but that seems like more trouble than it's worth (and is not well supported by TMA unless we increase the number of striding levels). Instead, we encode this quirk in the TMEM layout we use and make sure to rearrange the data properly while loading the tiles into registers.

PiperOrigin-RevId: 735791426
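To make the jumps concrete, here is a minimal, self-contained sketch (not part of the commit or of the test file below) of the column mapping the table above implies for a 2-CTA collective MMA with N=512 split into two N=256 MMAs. The helper name `tmem_col_to_global_n` and its signature are illustrative assumptions, not Mosaic GPU API.

```python
def tmem_col_to_global_n(cta: int, local_col: int, n_per_mma: int = 256) -> int:
  """Maps a (CTA, local TMEM column) pair to the logical N column of the result."""
  half = n_per_mma // 2            # columns each CTA contributes per MMA
  mma_idx = local_col // half      # which of the two MMAs wrote this TMEM column
  return mma_idx * half + cta * n_per_mma + local_col % half

# CTA 0's TMEM columns cover N 0:128 and 128:256; CTA 1's cover 256:384 and 384:512.
assert [tmem_col_to_global_n(0, c) for c in (0, 127, 128, 255)] == [0, 127, 128, 255]
assert [tmem_col_to_global_n(1, c) for c in (0, 127, 128, 255)] == [256, 383, 384, 511]
```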
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Mosaic GPU DSL functions and utilities."""

from collections.abc import Sequence
import dataclasses
import enum
import itertools
import math
import operator
import os
import re
import unittest

from absl.testing import absltest, parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect  # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import fragmented_array as fa
from jax.experimental.mosaic.gpu import tcgen05
import jax.numpy as jnp
import numpy as np
try:
  import jax._src.lib.mosaic_gpu  # noqa: F401
  HAS_MOSAIC_GPU = True
except ImportError:
  HAS_MOSAIC_GPU = False

  class Dimension(enum.IntEnum):  # Just to make parameterized tests expand ok
    x = 0
    y = 1
    z = 2
else:
  import jax.experimental.mosaic.gpu as mgpu
  from jax.experimental.mosaic.gpu import core
  from jax.experimental.mosaic.gpu import launch_context
  from jax.experimental.mosaic.gpu import utils as utils
  from jax.experimental.mosaic.gpu import profiler
  from jax.experimental.mosaic.gpu.utils import *  # noqa: F403
  from jax._src.lib.mlir.dialects import gpu
  from jax._src.lib.mlir.dialects import llvm
  Dimension = gpu.Dimension


# ruff: noqa: F405
# pylint: disable=g-complex-comprehension
config.parse_flags_with_absl()

def nd_loop(bounds, body, *, _idxs = ()):
  if not bounds:
    body(*_idxs)
    return
  bound, *other_bounds = bounds
  @fori(bound, ())
  def _loop_body(i, _):
    nd_loop(other_bounds, body, _idxs=(*_idxs, i))
    return ()


def mlir_sum(elems):
  assert elems
  total = elems[0]
  for elem in elems[1:]:
    total = arith.addi(total, elem)
  return total


def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
  index = ir.IndexType.get()
  thread_id = gpu.thread_id(gpu.Dimension.x)
  stride = gpu.block_dim(gpu.Dimension.x)
  for dim in (gpu.Dimension.y, gpu.Dimension.z):
    thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride))
    stride = arith.muli(stride, gpu.block_dim(dim))
  is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index))

  src_ty = ir.MemRefType(src.type)
  dst_ty = ir.MemRefType(dst.type)
  if src_ty.shape != dst_ty.shape:
    raise ValueError(
        f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}"
    )
  shape = src_ty.shape
  if src_ty.element_type != dst_ty.element_type:
    raise ValueError(
        f"src and dst element types don't match: {src_ty.element_type} !="
        f" {dst_ty.element_type}"
    )
  contig_strides = get_contiguous_strides(shape)
  # If swizzling is on, at least one of the memrefs must be contiguous
  # (simulating a TMA).
  if (swizzle is not None and
      src_ty.get_strides_and_offset()[0] != contig_strides and
      dst_ty.get_strides_and_offset()[0] != contig_strides):
    raise NotImplementedError(src_ty, dst_ty)

  bw = bitwidth(src_ty.element_type)
  if bw < 8:
    assert bw.bit_count() == 1
    packing = 8 // bw
    if shape[-1] % packing:
      raise NotImplementedError
    workgroup_mem = ir.Attribute.parse("#gpu.address_space<workgroup>")
    shape = (*shape[:-1], shape[-1] // packing)
    contig_strides = get_contiguous_strides(shape)
    def bitcast(ref):
      ref_ty = ir.MemRefType(ref.type)
      old_strides = ref_ty.get_strides_and_offset()[0]
      if old_strides[-1] != 1:
        raise NotImplementedError
      new_strides = [s // packing for s in old_strides[:-1]] + [1]
      new_ref_ty = ir.MemRefType.get(
          shape,
          ir.VectorType.get((packing,), src_ty.element_type),  # noqa: F821
          ir.StridedLayoutAttr.get(0, new_strides),
          ref_ty.memory_space,
      )
      ptr_space = (
          3
          if ref_ty.memory_space is not None
          and ref_ty.memory_space == workgroup_mem
          else None
      )
      return ptr_as_memref(
          # NOTE: memref_ptr applies the offset in case there was any.
          memref_ptr(ref, memory_space=ptr_space),
          new_ref_ty,
          ptr_memory_space=ptr_space,
      )
    src = bitcast(src)
    dst = bitcast(dst)
    bw = 8
  del src_ty, dst_ty  # If you remove this, update it in the branch above
  dyn_strides = [c(s, index) for s in contig_strides]

  with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block):
    def body(*idx):
      dst_idx = idx
      if swizzle is not None:
        assert swizzle.bit_count() == 1
        assert bw % 8 == 0
        bytes_per_element = bw // 8
        linear_idx = c(0, index)
        for stride, i in zip(dyn_strides, idx):
          linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
        # Swizzle pattern repeats every 128 bytes.
        swizzle_src = arith.remui(
            arith.divui(linear_idx, c(128 // bytes_per_element, index)),
            c(swizzle // 16, index),
        )
        # Swizzle happens in groups of 16 bytes.
        swizzle_shift = 4 - (bytes_per_element.bit_length() - 1)
        dst_linear_idx = arith.xori(
            linear_idx, arith.shli(swizzle_src, c(swizzle_shift, index))
        )
        dst_idx = [
            arith.remui(arith.divui(dst_linear_idx, stride), c(bound, index))
            for stride, bound in zip(dyn_strides, shape)
        ]
      memref.store(memref.load(src, idx), dst, dst_idx)
    nd_loop([c(d, index) for d in shape], body)
    scf.yield_([])
  gpu.barrier()
  nvvm.fence_proxy(nvvm.ProxyKind.async_)


def iota_tensor(m, n, dtype: jax.typing.DTypeLike, tiled_layout=False):
  assert m % 64 == 0
  assert n % 8 == 0
  def c(i):
    return arith.constant(index, ir.IntegerAttr.get(index, i))
  index = ir.IndexType.get()
  i32 = ir.IntegerType.get_signless(32)
  warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32))
  within_warp_id = arith.remui(gpu.thread_id(gpu.Dimension.x), c(32))
  warp_row_start = arith.muli(warp_id, c(16))
  within_warp_row = arith.divui(within_warp_id, c(4))
  start_row = arith.addi(warp_row_start, within_warp_row)
  start_col = arith.muli(arith.remui(within_warp_id, c(4)), c(2))
  registers = np.empty((m // 64, n // 8, 2, 1), dtype=object)
  for row_tile, col_tile, row_subtile, _ in np.ndindex(registers.shape):
    row = arith.addi(start_row, c(row_tile * 64 + row_subtile * 8))
    col = arith.addi(start_col, c(col_tile * 8))
    row_value_base = arith.muli(row, c(n))
    vec = llvm.mlir_undef(ir.VectorType.get((2,), i32))
    for col_offset in range(2):
      value = arith.addi(row_value_base, arith.addi(c(col_offset), col))
      value = arith.index_cast(i32, value)
      vec = vector.insertelement(value, vec, position=c(col_offset))
    registers[row_tile, col_tile, row_subtile, 0] = vec
  t = mgpu.FragmentedArray(
      _registers=registers, _layout=mgpu.WGMMA_LAYOUT, _is_signed=True
  )
  if tiled_layout:
    t = t.to_layout(mgpu.TILED_LAYOUT_WGMMA)
  return t.astype(
      utils.dtype_to_ir_type(dtype), is_signed=utils.is_signed(dtype)
  )


class TestCase(parameterized.TestCase):

  def setUp(self):
    if not HAS_MOSAIC_GPU:
      self.skipTest("jaxlib built without Mosaic GPU")
    if (not jtu.test_device_matches(["cuda"]) or
        not jtu.is_cuda_compute_capability_at_least("9.0")):
      self.skipTest("Only works on GPU with capability >= sm90")
    super().setUp()
    self.prng = np.random.default_rng(1234)
    self.context = mlir.make_ir_context()
    if mgpu_dialect is not None:
      mgpu_dialect.register_dialect(self.context)
    self.enter_context(config.traceback_filtering("off"))
    self.enter_context(self.context)
    self.enter_context(ir.Location.unknown())


class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest):

  def setUp(self):
    self.skip_unless_sm90a()
    super().setUp()


class TestUtilTest(TestCase):

  def test_copy_basic(self):
    def kernel(ctx, src, dst, _):
      copy(src, dst)
    x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3)
    y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
    np.testing.assert_array_equal(y, x)

  def test_copy_swizzle(self):
    def kernel(ctx, src, dst, _):
      copy(src, dst, swizzle=128)
    x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
    y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x)
    expected = np.zeros_like(y)
    for i in range(8):
      for j in range(8):
        js = j ^ i
        expected[i, (j * 4):(j * 4) + 4] = x[i, (js * 4):(js * 4) + 4]
    np.testing.assert_array_equal(y, expected)

  def test_copy_swizzle_noop(self):
    # Two swizzles cancel out
    def kernel(ctx, src, dst, smem):
      copy(src, smem, swizzle=128)
      copy(smem, dst, swizzle=128)
    x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32)
    y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
    np.testing.assert_array_equal(y, x)

  @parameterized.parameters(False, True)
  def test_iota_tensor(self, tiled_layout):
    m = n = 64
    def kernel(ctx, dst, _):
      f32 = ir.F32Type.get()
      index = ir.IndexType.get()
      registers = iota_tensor(m, n, jnp.float32, tiled_layout).registers
      assert registers.size == 16, registers.size
      for i, vec_reg in enumerate(registers.flat):
        for j in range(2):
          reg = vector.extractelement(vec_reg, position=c(j, index))
          memref.store(
              reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)]
          )
    out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
    regs = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
    )()
    thread_ids = np.arange(128)
    warp_ids = thread_ids // 32
    lane_ids = thread_ids % 32
    thread_rows = warp_ids * 16 + lane_ids // 4
    thread_start_cols = (lane_ids % 4) * 2
    thread_cols = thread_start_cols[:, None] + (np.arange(n // 8)[None] * 8)
    regs = regs.reshape(128, 8, 2, 2)
    for row_half in range(2):
      for col_half in range(2):
        np.testing.assert_array_equal(
            regs[..., row_half, col_half],
            (thread_rows[:, None] + row_half * 8) * n + thread_cols + col_half
        )


class MemRefTest(TestCase):
  @parameterized.product(
      dim=tuple(range(3)),
      strided=(False, True)
  )
  def test_unsqueeze(self, dim, strided):
    def kernel(ctx, inp, out, _):
      if strided:
        for i in range(8):
          s = ds(i, 1)
          out_slice = s if dim != 0 else (slice(None), s)
          copy(
              memref_unsqueeze(memref_slice(inp, s), dim),
              memref_slice(out, out_slice),
          )
      else:
        copy(memref_unsqueeze(inp, dim), out)
    x = np.arange(8 * 16, dtype=jnp.float32).reshape(8, 16)
    out_shape = list(x.shape)
    out_shape.insert(dim, 1)
    out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
    y = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
    )(x)
    np.testing.assert_array_equal(y, x.reshape(out_shape))

  @parameterized.product(
      dim=tuple(range(2)),
      strided=(False, True)
  )
  def test_unfold(self, dim, strided):
    in_shape = (8, 16)
    def kernel(ctx, inp, out, _):
      if strided:
        # We slice the dim we don't unfold
        for i in range(in_shape[1 - dim] // 4):
          s = ds(i * 4, 4)
          in_slice = s if dim == 1 else (slice(None), s)
          out_slice = s if dim == 1 else (slice(None),) * 3 + (s,)
          copy(
              memref_unfold(memref_slice(inp, in_slice), dim, (2, 2, None)),
              memref_slice(out, out_slice),
          )
      else:
        copy(memref_unfold(inp, dim, (2, 2, None)), out)
    x = np.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape)
    out_shape = list(in_shape)
    out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4]
    out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
    y = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
    )(x)
    np.testing.assert_array_equal(y, x.reshape(out_ty.shape))

  @parameterized.product(
      dim=tuple(range(2)),
  )
  def test_fold_not_strided(self, dim):
    def kernel(ctx, inp, out, _):
      copy(memref_fold(inp, dim, 2), out)

    x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8)
    out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32)
    y = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
    )(x)
    np.testing.assert_array_equal(y, x.reshape(out_ty.shape))

  @parameterized.named_parameters(
      ("drop_1s", (1, 1, 5, 1, 1, 2, 1, 1), (5, 1, 2)),
      ("add_1s", (5, 1, 2), (1, 1, 5, 1, 1, 2, 1, 1)),
      ("fold", (1, 5, 2, 1,), (1, 10, 1)),
      ("un", (1, 10, 1), (1, 5, 2, 1,)),
  )
  def test_reshape(self, inp_shape, out_shape):
    def kernel(ctx, inp, out, _):
      copy(memref_reshape(inp, out_shape), out)

    x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(*inp_shape)
    out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32)
    y = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), x, out_ty, ()
    )(x)
    np.testing.assert_array_equal(y, x.reshape(*out_shape))

  @parameterized.named_parameters([
      ("packed", (4, 4, 4), (16, 4, 1), 1, 2, False),
      ("strided_end", (4, 4, 4, 4), (256, 64, 16, 4), 1, 2, False),
      ("strided_bot", (4, 4, 4, 4), (256, 16, 4, 1), 1, 2, False),
      ("strided_top", (4, 4, 4, 4), (256, 64, 4, 1), 1, 2, True),
      ("strided_mid", (4, 4, 4, 4), (265, 64, 16, 1), 1, 3, True),
      # TODO(cperivol): Investigate why this is indexing OOB and uncomment.
      # ("overap", (2, 4, 4), (16, 1, 1), 0, 3, True),
  ])
  def test_fold_strided(
      self, shape, strides, dim, fold_rank, throws_not_impl
  ):
    expanded_shape = get_packed_shape(strides, shape)
    total_size = np.prod(expanded_shape)
    np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape)
    index = tuple(slice(0, s) for s in shape)

    # Reference implementation
    def np_fold(inp, dim, fold_rank):
      out_shape = list(inp.shape)
      out_shape[dim : dim + fold_rank] = [
          int(np.prod(inp.shape[dim : dim + fold_rank]))
      ]
      if throws_not_impl:
        return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype)
      else:
        return inp.reshape(*out_shape)

    total_size = np.prod(shape) * np.prod(strides)

    def do_test():
      def kernel(ctx, inp, out, _):
        copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out)

      out = np_fold(np_inp[index], dim, fold_rank)
      y = mgpu.as_gpu_kernel(
          kernel, (1, 1, 1), (128, 1, 1), np_inp, out, ()
      )(np_inp)
      assert (
          not throws_not_impl
      ), "If it should have thrown it would during the call."
      np.testing.assert_array_equal(y, out)

    if throws_not_impl:
      with self.assertRaises(NotImplementedError):
        do_test()
    else:
      do_test()

  @parameterized.parameters(jnp.uint64, jnp.uint32, jnp.uint16, jnp.uint8)
  def test_scalar_argument(self, dtype):
    if dtype == jnp.uint64 and not config.enable_x64.value:
      self.skipTest(
          "64-bit types are disabled: this leads to the input scalar being"
          " traced as a uint32 value, which causes the top 32 bits of the 64-bit"
          " values read from the 32-bit input buffer to sometimes"
          " (nondeterministically) contain garbage.")

    scalar = 42
    expected = np.full((128, 128), scalar, dtype=dtype)

    def kernel(ctx, inp, out, _):
      del ctx
      inp = memref.load(inp, [])
      mgpu.FragmentedArray.splat(inp, expected.shape, is_signed=True).store_untiled(out)

    res = mgpu.as_gpu_kernel(
        kernel,
        (1, 1, 1),
        (128, 1, 1),
        jax.ShapeDtypeStruct(shape=(), dtype=expected.dtype),
        expected,
        (),
    )(scalar)
    np.testing.assert_array_equal(res, expected)


def get_packed_shape(strides, shape):
  perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True)
  ordered_strides = [strides[i] for i in perm]
  ordered_shape = [shape[i] for i in perm]
  packed_shape = [ordered_shape[-1]]
  packed_shape += [
      stride0 // stride
      for stride0, stride in zip(ordered_strides, ordered_strides[1:])
  ]
  # Invert permutation
  inv_perm = [None] * len(perm)
  for i, p in enumerate(perm):
    inv_perm[p] = i
  return [packed_shape[i] for i in inv_perm]


class WGMMALayoutTest(TestCase):
|
|
|
|
@parameterized.product(dtype=[jnp.float16, jnp.float32],
|
|
tiled_layout=[False, True],
|
|
transposed_smem=[False, True])
|
|
def test_store_untiled(self, dtype, tiled_layout, transposed_smem):
|
|
def kernel(ctx, out, _):
|
|
del ctx
|
|
if transposed_smem:
|
|
out = memref_transpose(out, (1, 0))
|
|
iota_tensor(64, 64, dtype, tiled_layout=tiled_layout).store_untiled(
|
|
out, vector_store=not transposed_smem
|
|
)
|
|
expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64)
|
|
if transposed_smem:
|
|
expected = expected.T
|
|
iota = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
|
|
)()
|
|
np.testing.assert_array_equal(iota, expected)
|
|
|
|
@parameterized.product(
|
|
dtype=[jnp.float32, jnp.float16, jnp.int8],
|
|
swizzle=(32, 64, 128),
|
|
num_col_tiles=(1, 2, 3),
|
|
tiled_layout=(False, True),
|
|
)
|
|
def test_store_tiled(self, dtype, swizzle, num_col_tiles, tiled_layout):
|
|
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
|
if bytewidth(mlir_dtype) > 2 and swizzle == 32:
|
|
self.skipTest("Not implemented")
|
|
col_tiling = swizzle // bytewidth(mlir_dtype)
|
|
m = 128
|
|
n = col_tiling * num_col_tiles
|
|
tiling = (64, col_tiling)
|
|
def kernel(ctx, out, smem):
|
|
del ctx
|
|
iota_tensor(m, n, dtype, tiled_layout).store_tiled(smem, swizzle=swizzle)
|
|
copy(smem, out, swizzle=swizzle)
|
|
expected = (
|
|
np.arange(m * n, dtype=dtype)
|
|
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
|
|
.transpose(0, 2, 1, 3)
|
|
)
|
|
iota = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), expected, expected
|
|
)()
|
|
np.testing.assert_array_equal(iota, expected)
|
|
|
|
@parameterized.product(
|
|
dtype=[jnp.float16, jnp.int8],
|
|
swizzle=(32, 64, 128),
|
|
)
|
|
def test_store_tiled_short_n(self, dtype, swizzle):
|
|
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
|
col_tiling = swizzle // bytewidth(mlir_dtype)
|
|
m = 128
|
|
n = 16 // bytewidth(mlir_dtype)
|
|
tiling = (64, col_tiling)
|
|
def kernel(ctx, out, smem):
|
|
iota_tensor(m, n, dtype).store_tiled(smem, swizzle=swizzle)
|
|
ctx.async_copy(
|
|
src_ref=smem,
|
|
dst_ref=out,
|
|
swizzle=swizzle,
|
|
gmem_slice=(ds(0, m), ds(0, col_tiling)),
|
|
gmem_transform=mgpu.TileTransform(tiling),
|
|
)
|
|
ctx.await_async_copy(0)
|
|
smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), dtype)
|
|
expected = np.arange(m * n, dtype=dtype).reshape(m, n)
|
|
iota = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape
|
|
)()
|
|
np.testing.assert_array_equal(iota, expected)
|
|
|
|
@parameterized.named_parameters(
|
|
("bf16_i8", jnp.bfloat16, jnp.int8),
|
|
("i8_bf16", jnp.int8, jnp.bfloat16),
|
|
("i8_i8", jnp.int8, jnp.int8),
|
|
("i4_i4", jnp.int4, jnp.int4),
|
|
("i4_bf16", jnp.int4, jnp.bfloat16),
|
|
)
|
|
def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
|
|
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
|
|
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
|
|
m = 128
|
|
n = 256 * 8 // bitwidth(mlir_dtype_from)
|
|
def kernel(ctx, inp, out, smem):
|
|
del ctx
|
|
smem_from, smem_to = smem
|
|
copy(inp, smem_from, swizzle=128)
|
|
t = mgpu.FragmentedArray.load_tiled(
|
|
smem_from,
|
|
swizzle=128,
|
|
is_signed=utils.is_signed(jax_dtype_from),
|
|
layout=fa._tiled_wgmma_layout((m, n))
|
|
)
|
|
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
|
|
t.store_tiled(smem_to, swizzle=128)
|
|
copy(smem_to, out, swizzle=128)
|
|
|
|
from_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_from))
|
|
to_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_to))
|
|
# We only test lossless conversions for now.
|
|
# TODO(apaszke): Test and fix failures that appear with lossy conversions.
|
|
int_sample_dtype = getattr(
|
|
jnp,
|
|
"int" + str(min(bitwidth(mlir_dtype_from), bitwidth(mlir_dtype_to))),
|
|
)
|
|
sample_iinfo = jnp.iinfo(int_sample_dtype)
|
|
expected_raw = self.prng.integers(
|
|
low=sample_iinfo.min, high=sample_iinfo.max,
|
|
size=(m, n), dtype=np.int32
|
|
)
|
|
expected = lambda jax_dtype, tiling: expected_raw.reshape(
|
|
m // tiling[0], tiling[0], n // tiling[1], tiling[1]
|
|
).transpose(0, 2, 1, 3).astype(jax_dtype)
|
|
|
|
expected_from = expected(jax_dtype_from, from_tiling)
|
|
expected_to = expected(jax_dtype_to, to_tiling)
|
|
res = mgpu.as_gpu_kernel(
|
|
kernel,
|
|
(1, 1, 1),
|
|
(128, 1, 1),
|
|
expected_from,
|
|
expected_to,
|
|
(expected_from, expected_to),
|
|
)(expected_from)
|
|
np.testing.assert_array_equal(res, expected_to)
|
|
|
|
@parameterized.named_parameters(
|
|
("f32", jnp.float32),
|
|
("f16", jnp.float16),
|
|
("i8", jnp.int8),
|
|
)
|
|
def test_load_tiled(self, jax_dtype):
|
|
mlir_dtype = utils.dtype_to_ir_type(jax_dtype)
|
|
m = 128
|
|
n = 256 // bytewidth(mlir_dtype)
|
|
tiling = (64, 128 // bytewidth(mlir_dtype))
|
|
def kernel(ctx, in_, out, smem):
|
|
del ctx
|
|
smem1, smem2 = smem
|
|
copy(in_, smem1, swizzle=128)
|
|
t = mgpu.FragmentedArray.load_tiled(
|
|
smem1, swizzle=128, is_signed=utils.is_signed(jax_dtype)
|
|
)
|
|
t.store_tiled(smem2, swizzle=128)
|
|
copy(smem2, out, swizzle=128)
|
|
expected = (
|
|
np.arange(m * n, dtype=jax_dtype)
|
|
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
|
|
.transpose(0, 2, 1, 3)
|
|
)
|
|
iota = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2
|
|
)(expected)
|
|
np.testing.assert_array_equal(iota, expected)
|
|
|
|
|
|
class WGMMATest(TestCase):
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
if not jtu.is_cuda_compute_capability_equal("9.0"):
|
|
self.skipTest("Only works on GPU with capability sm90a")
|
|
|
|
@parameterized.product(
|
|
lhs_transpose=(False, True),
|
|
rhs_transpose=(False, True),
|
|
in_mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type),
|
|
m=(64, 128, 192),
|
|
n=(64, 128, 192),
|
|
k_steps=(1, 2),
|
|
swizzle=(32, 64, 128),
|
|
jax_out_dtype=(jnp.float16, jnp.float32),
|
|
rhs_tiling_kind=("large", "small", "small+no_transpose"),
|
|
lhs_tiling_kind=("large", "small", "small+no_transpose"),
|
|
)
|
|
def test_wgmma_basic(
|
|
self,
|
|
m,
|
|
n,
|
|
k_steps,
|
|
in_mlir_dtype_cls,
|
|
lhs_transpose,
|
|
rhs_transpose,
|
|
swizzle,
|
|
jax_out_dtype,
|
|
rhs_tiling_kind,
|
|
lhs_tiling_kind,
|
|
):
|
|
if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type:
|
|
self.skipTest("Only f16 input is supported for f16 output.")
|
|
if swizzle != 128 and lhs_transpose and lhs_tiling_kind == "large":
|
|
self.skipTest("Transpose only supported in 128B swizzled WGMMA")
|
|
if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose:
|
|
self.skipTest("No transpose happening anyway")
|
|
if lhs_tiling_kind == "small+no_transpose" and not lhs_transpose:
|
|
self.skipTest("No transpose happening anyway")
|
|
|
|
in_mlir_dtype = in_mlir_dtype_cls.get()
|
|
out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype)
|
|
if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead
|
|
in_jax_dtype = jnp.float32
|
|
if lhs_transpose or not rhs_transpose:
|
|
self.skipTest("Transpose only supported in 16-bit WGMMA")
|
|
exponent_bits, mantissa_bits = 8, 10 # Use tf32
|
|
elif bytewidth(in_mlir_dtype) == 2:
|
|
if n % 64 != 0:
|
|
self.skipTest("16-bit WGMMA only supports n % 64 == 0")
|
|
if ir.F16Type.isinstance(in_mlir_dtype):
|
|
in_jax_dtype = jnp.float16
|
|
exponent_bits, mantissa_bits = 5, 10
|
|
elif ir.BF16Type.isinstance(in_mlir_dtype):
|
|
in_jax_dtype = jnp.bfloat16
|
|
exponent_bits, mantissa_bits = 8, 7
|
|
else:
|
|
raise NotImplementedError(in_mlir_dtype)
|
|
else:
|
|
raise NotImplementedError(in_mlir_dtype)
|
|
nk_tile = swizzle // bytewidth(in_mlir_dtype)
|
|
k = nk_tile * k_steps
|
|
assert m % 64 == 0 and n % nk_tile == 0
|
|
|
|
small_rhs_tile = rhs_tiling_kind != "large"
|
|
transpose_rhs_tiles = rhs_tiling_kind != "small+no_transpose"
|
|
rhs_tiling = (8, nk_tile) if small_rhs_tile else (nk_tile, nk_tile)
|
|
small_lhs_tile = lhs_tiling_kind != "large"
|
|
transpose_lhs_tiles = lhs_tiling_kind != "small+no_transpose"
|
|
lhs_tiling = (8, nk_tile) if small_lhs_tile else (64, nk_tile)
|
|
|
|
def kernel(ctx, lhs, rhs, out, scratch):
|
|
lhs_smem, rhs_smem, barriers = scratch
|
|
lhs_transform = (mgpu.TileTransform(lhs_tiling),)
|
|
if lhs_transpose and transpose_lhs_tiles:
|
|
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
|
rhs_transform = (mgpu.TileTransform(rhs_tiling),)
|
|
if rhs_transpose and transpose_rhs_tiles:
|
|
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
|
ctx.async_copy(
|
|
src_ref=lhs,
|
|
dst_ref=lhs_smem,
|
|
swizzle=swizzle,
|
|
gmem_transform=lhs_transform,
|
|
barrier=barriers[0],
|
|
)
|
|
ctx.async_copy(
|
|
src_ref=rhs,
|
|
dst_ref=rhs_smem,
|
|
swizzle=swizzle,
|
|
gmem_transform=rhs_transform,
|
|
barrier=barriers[1],
|
|
)
|
|
for i in range(2):
|
|
barriers[i].wait()
|
|
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype)
|
|
if lhs_transpose:
|
|
perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2)
|
|
lhs_smem = memref_transpose(lhs_smem, perm)
|
|
if rhs_transpose:
|
|
perm = (0, 1, 3, 2) if transpose_rhs_tiles else (1, 0, 3, 2)
|
|
rhs_smem = memref_transpose(rhs_smem, perm)
|
|
acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle)
|
|
nvvm.wgmma_commit_group_sync_aligned()
|
|
nvvm.wgmma_wait_group_sync_aligned(0)
|
|
acc.value.store_untiled(out)
|
|
|
|
def quantize(x):
|
|
# Quantize the input to avoid rounding when feeding the WGMMA
|
|
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
|
|
|
|
x_shape = (k, m) if lhs_transpose else (m, k)
|
|
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype)
|
|
y_shape = (n, k) if rhs_transpose else (k, n)
|
|
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype)
|
|
if transpose_rhs_tiles:
|
|
rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling
|
|
rhs_smem_shape = (k // rhs_tiling_t[0], n // rhs_tiling_t[1], *rhs_tiling)
|
|
else:
|
|
rhs_smem_shape = tile_shape(y_shape, rhs_tiling)
|
|
if transpose_lhs_tiles:
|
|
lhs_tiling_t = lhs_tiling[::-1] if lhs_transpose else lhs_tiling
|
|
lhs_smem_shape = (m // lhs_tiling_t[0], k // lhs_tiling_t[1], *lhs_tiling)
|
|
else:
|
|
lhs_smem_shape = tile_shape(x_shape, lhs_tiling)
|
|
scratch_shape = [
|
|
jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype),
|
|
jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype),
|
|
mgpu.TMABarrier(2),
|
|
]
|
|
z = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
|
|
)(x, y)
|
|
x32, y32 = x.astype(np.float32), y.astype(np.float32)
|
|
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
|
|
atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6
|
|
np.testing.assert_allclose(z, ref, atol=atol)
|
|
|
|
# TODO(apaszke): Add support for f32
|
|
@parameterized.product(
|
|
m=(64, 128, 192),
|
|
n=(64, 128, 192),
|
|
k_steps=(1, 2),
|
|
rhs_transpose=(False, True),
|
|
swizzle=(32, 64, 128),
|
|
dtype=[jnp.float16, jnp.bfloat16],
|
|
tiled_layout=(False, True),
|
|
)
|
|
def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, swizzle, dtype, tiled_layout):
|
|
index = ir.IndexType.get()
|
|
|
|
bytewidth = 2
|
|
nk_tile = swizzle // bytewidth
|
|
k = nk_tile * k_steps
|
|
|
|
def kernel(ctx, rhs, out, rhs_smem):
|
|
del ctx
|
|
for ki in range(k_steps):
|
|
for ni in range(n // nk_tile):
|
|
rhs_slice = (
|
|
ds(c(ki * nk_tile, index), nk_tile),
|
|
ds(c(ni * nk_tile, index), nk_tile),
|
|
)
|
|
if rhs_transpose:
|
|
rhs_slice = rhs_slice[::-1]
|
|
copy(
|
|
src=memref_slice(rhs, rhs_slice),
|
|
dst=memref_slice(rhs_smem, (ki, ni)),
|
|
swizzle=swizzle,
|
|
)
|
|
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
|
|
lhs_regs = iota_tensor(m, k, dtype, tiled_layout)
|
|
if rhs_transpose:
|
|
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
|
|
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
|
|
nvvm.wgmma_commit_group_sync_aligned()
|
|
nvvm.wgmma_wait_group_sync_aligned(0)
|
|
acc.value.store_untiled(out)
|
|
|
|
y_shape = (n, k) if rhs_transpose else (k, n)
|
|
y = self.prng.uniform(-1, 1, y_shape).astype(dtype)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
|
scratch_shape = jax.ShapeDtypeStruct(
|
|
(k_steps, n // nk_tile, nk_tile, nk_tile), dtype
|
|
)
|
|
z = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape
|
|
)(y)
|
|
x = np.arange(m * k, dtype=dtype).reshape(m, k)
|
|
ref = jax.lax.dot(
|
|
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
|
|
)
|
|
rtol = 5e-4
|
|
np.testing.assert_allclose(z, ref, rtol=rtol, atol=0)
|
|
|
|
@parameterized.product(
|
|
rhs_transpose=(False, True),
|
|
swizzle=(32, 64, 128),
|
|
n=(8, 16),
|
|
small_rhs_tile=(False, True),
|
|
)
|
|
def test_narrow_n(self, rhs_transpose, swizzle, n, small_rhs_tile):
|
|
m, k_steps = 64, 2
|
|
bytewidth = 2
|
|
nk_tile = swizzle // bytewidth
|
|
k = nk_tile * k_steps
|
|
if small_rhs_tile and not rhs_transpose:
|
|
self.skipTest("Small tiles only supported for transposed RHS")
|
|
|
|
n_tile = 8 if small_rhs_tile else nk_tile
|
|
|
|
def kernel(ctx, rhs, out, smem):
|
|
rhs_smem, barrier = smem
|
|
gmem_slice = (ds(0, k), ds(0, max(n_tile, n)))
|
|
transform = (mgpu.TileTransform((n_tile, nk_tile)),)
|
|
if rhs_transpose:
|
|
gmem_slice = gmem_slice[::-1]
|
|
transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
|
ctx.async_copy(
|
|
src_ref=rhs,
|
|
dst_ref=rhs_smem,
|
|
swizzle=swizzle,
|
|
gmem_slice=gmem_slice,
|
|
gmem_transform=transform,
|
|
barrier=barrier,
|
|
)
|
|
barrier.wait()
|
|
init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n)
|
|
lhs_regs = iota_tensor(m, k, jnp.float16)
|
|
if rhs_transpose:
|
|
rhs_smem = memref_transpose(rhs_smem, (0, 1, 3, 2))
|
|
if not small_rhs_tile:
|
|
smem_slice = (slice(None), slice(None), slice(None), ds(0, n))
|
|
rhs_smem = memref_slice(rhs_smem, smem_slice)
|
|
acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle)
|
|
nvvm.wgmma_commit_group_sync_aligned()
|
|
nvvm.wgmma_wait_group_sync_aligned(0)
|
|
acc.value.store_untiled(out)
|
|
|
|
jax_dtype = jnp.float16
|
|
y_shape = (n, k) if rhs_transpose else (k, n)
|
|
y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
|
rhs_scratch_shape = jax.ShapeDtypeStruct(
|
|
(k_steps, (n + n_tile - 1) // n_tile, n_tile, nk_tile), jax_dtype
|
|
)
|
|
z = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()),
|
|
)(y)
|
|
x = np.arange(m * k, dtype=jax_dtype).reshape(m, k)
|
|
ref = jax.lax.dot(
|
|
x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32
|
|
)
|
|
np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0)
|
|
|
|
|
|
class TCGen05Test(TestCase):
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
capabilities = ("10.0", "10.1")
|
|
if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities):
|
|
self.skipTest("Only works on GPU with capability sm_100a or sm_101a")
|
|
|
|
@parameterized.product(
|
|
lhs_transpose=(False, True),
|
|
rhs_transpose=(False, True),
|
|
in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32
|
|
out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation
|
|
m=(128,), # TODO(apaszke): 64, 192, 256
|
|
n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2
|
|
k_steps=(1, 2),
|
|
swizzle=(32, 64, 128,),
|
|
rhs_transpose_tiles=(False, True),
|
|
lhs_transpose_tiles=(False, True),
|
|
)
|
|
def test_mma_basic(
|
|
self,
|
|
m,
|
|
n,
|
|
k_steps,
|
|
swizzle,
|
|
lhs_transpose,
|
|
rhs_transpose,
|
|
in_jax_dtype,
|
|
out_jax_dtype,
|
|
rhs_transpose_tiles,
|
|
lhs_transpose_tiles,
|
|
):
|
|
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
|
|
self.skipTest("Only f16 input is supported for f16 output.")
|
|
|
|
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
|
|
swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
|
|
k = swizzle_elems * k_steps
|
|
lhs_tiling = rhs_tiling = (8, swizzle_elems)
|
|
|
|
def kernel(ctx, lhs, rhs, out, scratch):
|
|
lhs_smem, rhs_smem, barriers, acc = scratch
|
|
lhs_transform = (mgpu.TileTransform(lhs_tiling),)
|
|
if lhs_transpose_tiles:
|
|
lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
|
rhs_transform = (mgpu.TileTransform(rhs_tiling),)
|
|
if rhs_transpose_tiles:
|
|
rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),)
|
|
ctx.async_copy(
|
|
src_ref=lhs,
|
|
dst_ref=lhs_smem,
|
|
swizzle=swizzle,
|
|
gmem_transform=lhs_transform,
|
|
barrier=barriers[0],
|
|
)
|
|
ctx.async_copy(
|
|
src_ref=rhs,
|
|
dst_ref=rhs_smem,
|
|
swizzle=swizzle,
|
|
gmem_transform=rhs_transform,
|
|
barrier=barriers[1],
|
|
)
|
|
barriers[0].wait()
|
|
barriers[1].wait()
|
|
with mgpu.single_thread():
|
|
if lhs_transpose_tiles:
|
|
lhs_smem = memref_transpose(lhs_smem, (1, 0, 2, 3))
|
|
if lhs_transpose:
|
|
lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2))
|
|
if rhs_transpose_tiles:
|
|
rhs_smem = memref_transpose(rhs_smem, (1, 0, 2, 3))
|
|
if rhs_transpose:
|
|
rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2))
|
|
tcgen05.mma(
|
|
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False,
|
|
)
|
|
tcgen05.commit_arrive(barriers[2])
|
|
barriers[2].wait(for_tensor_core=True)
|
|
acc[:].store_untiled(out)
|
|
|
|
in_finfo = jnp.finfo(in_jax_dtype)
|
|
exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant
|
|
def quantize(x):
|
|
# Quantize the input to avoid rounding when feeding the TensorCore
|
|
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
|
|
|
|
x_shape = (k, m) if lhs_transpose else (m, k)
|
|
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype)
|
|
y_shape = (n, k) if rhs_transpose else (k, n)
|
|
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
|
|
if rhs_transpose_tiles:
|
|
rhs_smem_shape = (
|
|
y_shape[1] // rhs_tiling[1], y_shape[0] // rhs_tiling[0], *rhs_tiling,
|
|
)
|
|
else:
|
|
rhs_smem_shape = tile_shape(y_shape, rhs_tiling)
|
|
if lhs_transpose_tiles:
|
|
lhs_smem_shape = (
|
|
x_shape[1] // lhs_tiling[1], x_shape[0] // lhs_tiling[0], *lhs_tiling,
|
|
)
|
|
else:
|
|
lhs_smem_shape = tile_shape(x_shape, lhs_tiling)
|
|
scratch_shape = [
|
|
jax.ShapeDtypeStruct(lhs_smem_shape, in_jax_dtype),
|
|
jax.ShapeDtypeStruct(rhs_smem_shape, in_jax_dtype),
|
|
mgpu.TMABarrier(3),
|
|
mgpu.TMEM((128, n), out_jax_dtype),
|
|
]
|
|
z = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape
|
|
)(x, y)
|
|
x32, y32 = x.astype(np.float32), y.astype(np.float32)
|
|
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
|
|
atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6
|
|
np.testing.assert_allclose(z, ref, atol=atol)
|
|
|
|
@parameterized.product(
|
|
lhs_transpose=(False, True),
|
|
rhs_transpose=(False, True),
|
|
in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32
|
|
out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation
|
|
m=(256,), # TODO(apaszke): 64, 192, 256
|
|
n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2
|
|
k_steps=(1, 2),
|
|
swizzle=(32, 64, 128,),
|
|
)
|
|
def test_mma_collective(
|
|
self,
|
|
m,
|
|
n,
|
|
k_steps,
|
|
swizzle,
|
|
lhs_transpose,
|
|
rhs_transpose,
|
|
in_jax_dtype,
|
|
out_jax_dtype,
|
|
):
|
|
if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16:
|
|
self.skipTest("Only f16 input is supported for f16 output.")
|
|
|
|
in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype)
|
|
m_block_tile = m // 2
|
|
n_block_tile = n // 2
|
|
swizzle_elems = swizzle // bytewidth(in_mlir_dtype)
|
|
k = swizzle_elems * k_steps
|
|
index = ir.IndexType.get()
|
|
|
|
tiling = (8, swizzle_elems)
|
|
|
|
def kernel(ctx, lhs, rhs, out, scratch):
|
|
lhs_smem, rhs_smem, barriers, acc = scratch
|
|
block_id = gpu.cluster_block_id(gpu.Dimension.x)
|
|
ctx.async_copy(
|
|
src_ref=lhs,
|
|
dst_ref=lhs_smem,
|
|
swizzle=swizzle,
|
|
gmem_transform=mgpu.TileTransform(tiling),
|
|
barrier=barriers[0],
|
|
collective=gpu.Dimension.x,
|
|
partitioned=1 if lhs_transpose else 0, # Split non-contracting dim.
|
|
)
|
|
ctx.async_copy(
|
|
src_ref=rhs,
|
|
dst_ref=rhs_smem,
|
|
swizzle=swizzle,
|
|
gmem_transform=mgpu.TileTransform(tiling),
|
|
barrier=barriers[1],
|
|
collective=gpu.Dimension.x,
|
|
partitioned=0 if rhs_transpose else 1, # Split non-contracting dim.
|
|
)
|
|
is_leader_thread = single_thread_predicate()
|
|
is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index))
|
|
with when(arith.andi(is_first_block, is_leader_thread)):
|
|
barriers[0].wait()
|
|
barriers[1].wait()
|
|
if lhs_transpose:
|
|
lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2))
|
|
if rhs_transpose:
|
|
rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2))
|
|
tcgen05.mma(
|
|
acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, collective=True
|
|
)
|
|
tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx)
|
|
barriers[2].wait(for_tensor_core=True)
|
|
m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile)
|
|
acc[:].store_untiled(memref_slice(out, m_slice))
|
|
|
|
in_finfo = jnp.finfo(in_jax_dtype)
|
|
exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant
|
|
def quantize(x):
|
|
# Quantize the input to avoid rounding when feeding the TensorCore
|
|
return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits)
|
|
|
|
x_shape = (k, m) if lhs_transpose else (m, k)
|
|
x_block_shape = (k, m_block_tile) if lhs_transpose else (m_block_tile, k)
|
|
x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype)
|
|
y_shape = (n, k) if rhs_transpose else (k, n)
|
|
y_block_shape = (n_block_tile, k) if rhs_transpose else (k, n_block_tile)
|
|
y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype)
|
|
scratch_shape = [
|
|
jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype),
|
|
jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype),
|
|
mgpu.TMABarrier(3),
|
|
mgpu.TMEM((128, n), out_jax_dtype, collective=True),
|
|
]
|
|
z = mgpu.as_gpu_kernel(
|
|
kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1)
|
|
)(x, y)
|
|
x32, y32 = x.astype(np.float32), y.astype(np.float32)
|
|
ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32)
|
|
atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6
|
|
np.testing.assert_allclose(z, ref, atol=atol)
|
|
|
|
|
|
class BarrierTest(TestCase):
|
|
|
|
def test_wg_communication(self):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
def kernel(ctx, dst, scratch):
|
|
tmp, barriers = scratch
|
|
del ctx # Unused.
|
|
wg_idx = arith.divui(mgpu.warp_idx(), c(4, i32))
|
|
is_first_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(0, i32))
|
|
is_second_wg = arith.cmpi(arith.CmpIPredicate.eq, wg_idx, c(1, i32))
|
|
arr = mgpu.FragmentedArray.splat(
|
|
arith.addi(wg_idx, c(1, i32)),
|
|
(128,),
|
|
mgpu.WGStridedFragLayout((128,), 1),
|
|
is_signed=False,
|
|
)
|
|
with ir.InsertionPoint(scf.IfOp(is_first_wg).then_block):
|
|
arr.store_untiled(tmp)
|
|
barriers[0].arrive() # Signal that tmp is ready.
|
|
barriers[1].wait() # Wait for the other warp to produce tmp.
|
|
final_arr = arr + mgpu.FragmentedArray.load_strided(
|
|
tmp, is_signed=False
|
|
)
|
|
final_arr.store_untiled(memref_slice(dst, 0))
|
|
scf.yield_([])
|
|
with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block):
|
|
barriers[0].wait()
|
|
final_arr = arr + mgpu.FragmentedArray.load_strided(
|
|
tmp, is_signed=False
|
|
)
|
|
barriers[2].arrive()
|
|
barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp.
|
|
arr.store_untiled(tmp)
|
|
barriers[1].arrive() # Signal that tmp is ready.
|
|
final_arr.store_untiled(memref_slice(dst, 1))
|
|
scf.yield_([])
|
|
out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32)
|
|
y = mgpu.as_gpu_kernel(
|
|
kernel,
|
|
(1, 1, 1),
|
|
(2 * 128, 1, 1),
|
|
(),
|
|
out_shape,
|
|
(
|
|
jax.ShapeDtypeStruct((128,), jnp.int32),
|
|
mgpu.Barrier(arrival_count=128, num_barriers=3),
|
|
),
|
|
)()
|
|
np.testing.assert_array_equal(y, np.full_like(y, 3, dtype=np.int32))
|
|
|
|
@parameterized.named_parameters(
|
|
(
|
|
f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}{'_group' if group_dims else ''}",
|
|
collective_dims,
|
|
noncollective_dims,
|
|
collective_size,
|
|
group_dims,
|
|
)
|
|
for collective_dims in itertools.chain.from_iterable(
|
|
itertools.combinations(Dimension, n) for n in range(1, 4)
|
|
)
|
|
for noncollective_dims in itertools.chain.from_iterable(
|
|
itertools.combinations(Dimension, n) for n in range(3)
|
|
)
|
|
for collective_size in (1, 2, 4)
|
|
for group_dims in (False,) + ((True,) if len(collective_dims) > 1 else ())
|
|
if all(d not in noncollective_dims for d in collective_dims)
|
|
)
|
|
def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size, group_dims):
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
index = ir.IndexType.get()
|
|
cluster = [1, 1, 1]
|
|
for d in collective_dims:
|
|
cluster[d] = collective_size
|
|
for d in noncollective_dims:
|
|
cluster[d] = 2
|
|
if math.prod(cluster) > 16:
|
|
self.skipTest("Cluster too big")
|
|
is_trivial = math.prod(cluster[d] for d in collective_dims) == 1
|
|
def kernel(ctx, dst, mask, collective_barrier):
|
|
cluster_idx = ctx.cluster_idx()
|
|
if not is_trivial:
|
|
memref.store(collective_barrier.cluster_mask, mask, [cluster_idx])
|
|
else:
|
|
assert collective_barrier.cluster_mask is None
|
|
collective_barrier.arrive()
|
|
collective_barrier.wait()
|
|
tid = thread_idx()
|
|
linear_idx = arith.index_cast(index, tid)
|
|
stride = c(128, index)
|
|
for d in gpu.Dimension:
|
|
linear_idx = arith.addi(linear_idx, arith.muli(gpu.block_id(d), stride))
|
|
stride = arith.muli(stride, gpu.grid_dim(d))
|
|
memref.store(arith.index_cast(i32, linear_idx), dst, [linear_idx])
|
|
out_shape = jax.ShapeDtypeStruct((math.prod(cluster) * 128,), jnp.int32)
|
|
mask_shape = jax.ShapeDtypeStruct((math.prod(cluster),), jnp.int32)
|
|
barrier_dims = collective_dims
|
|
if group_dims:
|
|
barrier_dims = (collective_dims[:2], *collective_dims[2:])
|
|
scratch = mgpu.ClusterBarrier(barrier_dims)
|
|
y, mask = mgpu.as_gpu_kernel(
|
|
kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster,
|
|
)()
|
|
np.testing.assert_array_equal(
|
|
y, np.arange(math.prod(cluster) * 128, dtype=np.int32)
|
|
)
|
|
if not is_trivial:
|
|
# Verify that the mask is correct. Blocks are column-major, hence the transpose.
|
|
block_bits = 1 << np.arange(math.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T
|
|
expected_mask = 0
|
|
for bd in barrier_dims:
|
|
if isinstance(bd, gpu.Dimension):
|
|
bd = (bd,)
|
|
least_significant_slice = tuple(
|
|
slice(None) if d in bd else 0 for d in gpu.Dimension
|
|
)
|
|
mask_bits = block_bits[least_significant_slice]
|
|
expected_mask |= np.bitwise_or.reduce(mask_bits, axis=None)
|
|
self.assertEqual(min(mask), expected_mask)
|
|
|
|
|
|
class TMATest(TestCase):
|
|
|
|
@parameterized.product(
|
|
swizzle=(None, 32, 64, 128),
|
|
shape=((64, None), (5, None), (2, 3, 5, None)),
|
|
dtype=(jnp.float32, jnp.float16, jnp.int4),
|
|
)
|
|
def test_tma_load_basic(self, swizzle, shape, dtype):
|
|
bw = bitwidth(dtype_to_ir_type(dtype))
|
|
minor_size = 64 if swizzle is None else 8 * swizzle // bw
|
|
shape = (*shape[:-1], minor_size)
|
|
i1 = ir.IntegerType.get_signless(1)
|
|
def kernel(ctx, src, dst, smem):
|
|
tmp, barrier = smem
|
|
ctx.async_copy(src_ref=src, dst_ref=tmp, swizzle=swizzle, barrier=barrier)
|
|
barrier.wait_parity(c(0, i1))
|
|
copy(tmp, dst, swizzle=swizzle)
|
|
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
|
smem = (x, mgpu.TMABarrier())
|
|
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x)
|
|
np.testing.assert_array_equal(y, x)
|
|
|
|
@parameterized.named_parameters(
|
|
(
|
|
f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}",
|
|
collective_dims,
|
|
noncollective_dims,
|
|
collective_size,
|
|
)
|
|
for collective_dims in itertools.chain.from_iterable(
|
|
itertools.combinations(Dimension, n) for n in range(1, 4)
|
|
)
|
|
for noncollective_dims in itertools.chain.from_iterable(
|
|
itertools.combinations(Dimension, n) for n in range(3)
|
|
)
|
|
for collective_size in (1, 2, 4)
|
|
if all(d not in noncollective_dims for d in collective_dims)
|
|
)
|
|
def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size):
|
|
index = ir.IndexType.get()
|
|
swizzle = 128
|
|
dtype = jnp.float16
|
|
cluster = [1, 1, 1]
|
|
for d in collective_dims:
|
|
cluster[d] = collective_dim_size
|
|
for d in noncollective_dims:
|
|
cluster[d] = 2
|
|
if math.prod(cluster) > 16:
|
|
self.skipTest("Cluster too big")
|
|
collective_size = math.prod(cluster[d] for d in collective_dims)
|
|
noncollective_size = math.prod(cluster) // collective_size
|
|
# We use the 2 dimension to exercise splitting the collective over
|
|
# multiple dimensions when the cluster is large.
|
|
shape = (noncollective_size, 2, 16 * collective_size, 64)
|
|
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
|
|
shape = (*shape[:-1], minor_size)
|
|
# Note that this kernel does not use the non-collective dimensions in any
|
|
# interesting way and so they don't really have to be part of the cluster.
|
|
# We use them to test that the multicast mask is generated correctly.
|
|
def kernel(ctx, src, dst, scratch):
|
|
tmp, barrier = scratch
|
|
stride = 1
|
|
noncollective_idx = c(0, index)
|
|
for d in noncollective_dims:
|
|
noncollective_idx = arith.addi(
|
|
noncollective_idx,
|
|
arith.muli(gpu.cluster_block_id(d), c(stride, index))
|
|
)
|
|
stride *= cluster[d]
|
|
ctx.async_copy(
|
|
src_ref=src,
|
|
dst_ref=tmp,
|
|
gmem_slice=(noncollective_idx,),
|
|
swizzle=swizzle,
|
|
barrier=barrier,
|
|
collective=collective_dims,
|
|
)
|
|
barrier.wait()
|
|
# This is _not_ the real cluster block idx, because it does not consider
|
|
# the column-major ordering of the grid dimensions.
|
|
idx = c(0, index)
|
|
stride = 1
|
|
for d in collective_dims:
|
|
idx = arith.addi(
|
|
idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))
|
|
)
|
|
stride *= cluster[d]
|
|
idx_minor = arith.divui(idx, c(2, index))
|
|
idx_major = arith.remui(idx, c(2, index))
|
|
slc_minor = ds(
|
|
arith.muli(idx_minor, c(16 * 2, index)), 16 * 2
|
|
)
|
|
copy(
|
|
memref_slice(tmp, (idx_major, slc_minor)),
|
|
memref_slice(dst, (noncollective_idx, idx_major, slc_minor)),
|
|
swizzle=swizzle,
|
|
)
|
|
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
|
smem_shape = (jax.ShapeDtypeStruct(shape[1:], dtype), mgpu.TMABarrier())
|
|
y = mgpu.as_gpu_kernel(
|
|
kernel, cluster, (128, 1, 1), x, x, smem_shape, cluster=cluster
|
|
)(x)
|
|
np.testing.assert_array_equal(y, x)
|
|
|
|
@parameterized.product(
|
|
swizzle=(None, 128),
|
|
shape=((128, 128), (5, 32, 128)),
|
|
dtype=(jnp.float16, jnp.float32),
|
|
)
|
|
def test_tma_load_tiled(self, swizzle, shape, dtype):
|
|
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
|
|
# swizzle 32 and 64.
|
|
i1 = ir.IntegerType.get_signless(1)
|
|
index = ir.IndexType.get()
|
|
tiling = (32, (swizzle or 128) // jnp.dtype(dtype).itemsize)
|
|
tiled_shape = tile_shape(shape, tiling)[:len(shape)]
|
|
def kernel(ctx, src, dst, scratch):
|
|
tmp, barrier = scratch
|
|
ctx.async_copy(
|
|
src_ref=src,
|
|
dst_ref=tmp,
|
|
swizzle=swizzle,
|
|
barrier=barrier,
|
|
gmem_transform=mgpu.TileTransform(tiling),
|
|
)
|
|
barrier.wait_parity(c(0, i1))
|
|
for idxs in np.ndindex(tiled_shape):
|
|
untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):]
|
|
s = (
|
|
*untiled_idxs,
|
|
*(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)),
|
|
)
|
|
copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle)
|
|
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
|
smem = (
|
|
jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype),
|
|
mgpu.TMABarrier(),
|
|
)
|
|
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)
|
|
y = f(x)
|
|
np.testing.assert_array_equal(y, x)
|
|
|
|
@parameterized.product(swizzle=(None, 128))
|
|
def test_tma_load_tiled_rounding(self, swizzle):
|
|
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
|
|
# swizzle 32 and 64.
|
|
shape = (5, 32, 144)
|
|
dtype = jnp.float16
|
|
i1 = ir.IntegerType.get_signless(1)
|
|
index = ir.IndexType.get()
|
|
tiling = (32, (swizzle or 128) // jnp.dtype(dtype).itemsize)
|
|
rounded_shape = (*shape[:-1], shape[-1] // tiling[-1] * tiling[-1])
|
|
tiled_shape = tile_shape(rounded_shape, tiling)[:len(shape)]
|
|
def kernel(ctx, src, dst, scratch):
|
|
tmp, barrier = scratch
|
|
ctx.async_copy(
|
|
src_ref=src,
|
|
dst_ref=tmp,
|
|
swizzle=swizzle,
|
|
barrier=barrier,
|
|
gmem_transform=mgpu.TileTransform(
|
|
tiling, rounding=mgpu.Rounding.DOWN
|
|
),
|
|
)
|
|
barrier.wait_parity(c(0, i1))
|
|
for idxs in np.ndindex(tiled_shape):
|
|
untiled_idxs, tiled_idxs = idxs[:-len(tiling)], idxs[-len(tiling):]
|
|
s = (
|
|
*untiled_idxs,
|
|
*(ds(c(ix * t, index), t) for ix, t in zip(tiled_idxs, tiling)),
|
|
)
|
|
copy(memref_slice(tmp, idxs), memref_slice(dst, s), swizzle=swizzle)
|
|
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
|
tmp_shape = jax.ShapeDtypeStruct(tile_shape(rounded_shape, tiling), dtype)
|
|
smem = (tmp_shape, mgpu.TMABarrier())
|
|
out_shape = jax.ShapeDtypeStruct(rounded_shape, dtype)
|
|
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem)
|
|
y = f(x)
|
|
np.testing.assert_array_equal(y, x[..., :rounded_shape[-1]])
|
|
|
|
def test_tma_load_indexed_tiled(self):
|
|
shape = (128, 2, 128)
|
|
tiling = mgpu.TileTransform((32, 32))
|
|
def kernel(ctx, src, dst, scratch):
|
|
tmp, barrier = scratch
|
|
ctx.async_copy(
|
|
src_ref=src,
|
|
dst_ref=tmp,
|
|
barrier=barrier,
|
|
gmem_transform=tiling,
|
|
gmem_slice=(slice(None), 1, slice(None)),
|
|
)
|
|
barrier.wait()
|
|
ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_transform=tiling)
|
|
ctx.await_async_copy(0)
|
|
x = np.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
|
|
smem = (
|
|
jax.ShapeDtypeStruct((4, 4, 32, 32), jnp.float32),
|
|
mgpu.TMABarrier(),
|
|
)
|
|
out_shape = jax.ShapeDtypeStruct((128, 128), jnp.float32)
|
|
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem)
|
|
np.testing.assert_array_equal(f(x), x[:, 1, :])
|
|
|
|
@parameterized.product(
|
|
swizzle=(None, 128),
|
|
dtype=(jnp.float16, jnp.float32),
|
|
)
|
|
def test_tma_squeeze_indexing(self, swizzle, dtype):
|
|
# TODO(apaszke): ptxas seems to freeze when generating code for copy with
|
|
# swizzle 32 and 64.
|
|
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
|
|
shape = (4, 5, minor_size)
|
|
def kernel(ctx, src, dst, smem):
|
|
tmp, barrier = smem
|
|
for i in range(4):
|
|
ctx.async_copy(
|
|
src_ref=src,
|
|
dst_ref=memref_slice(tmp, i),
|
|
gmem_slice=i,
|
|
swizzle=swizzle,
|
|
barrier=barrier,
|
|
)
|
|
barrier.wait()
|
|
copy(tmp, dst, swizzle=swizzle)
|
|
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
|
smem = (x, mgpu.TMABarrier())
|
|
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x)
|
|
np.testing.assert_array_equal(y, x)
|
|
|
|
def test_parity_tracking(self):
|
|
shape = (16, 64)
|
|
index = ir.IndexType.get()
|
|
def kernel(ctx, src, dst, smem):
|
|
tmp, barrier = smem
|
|
for i in range(shape[0]):
|
|
s = ds(c(i, index), 1)
|
|
ctx.async_copy(
|
|
src_ref=src, dst_ref=tmp, gmem_slice=s, barrier=barrier,
|
|
)
|
|
barrier.wait()
|
|
copy(tmp, memref_slice(dst, s))
|
|
x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape)
|
|
y = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), x, x, (x[0:1], mgpu.TMABarrier())
|
|
)(x)
|
|
np.testing.assert_array_equal(y, x)
|
|
|
|
@parameterized.product(
|
|
swizzle=(None, 32, 64, 128),
|
|
shape=((64, None), (5, None), (2, 3, 5, None)),
|
|
dtype=(jnp.float16, jnp.float32),
|
|
)
|
|
def test_tma_store(self, swizzle, shape, dtype):
|
|
minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
|
|
shape = (*shape[:-1], minor_size)
|
|
def kernel(ctx, src, dst, tmp):
|
|
copy(src, tmp, swizzle=swizzle)
|
|
ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle)
|
|
ctx.await_async_copy(0)
|
|
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
|
y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)
|
|
np.testing.assert_array_equal(y, x)
|
|
|
|
@parameterized.parameters(0, 1)
|
|
def test_tma_small_tile_load(self, small_dim):
|
|
if small_dim == 0:
|
|
shape = (4, 128)
|
|
elif small_dim == 1:
|
|
shape = (128, 8)
|
|
else:
|
|
raise ValueError("small_dim must be 0 or 1")
|
|
tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64)
|
|
padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2]))
|
|
def kernel(ctx, src, dst, smem):
|
|
tmp, barrier = smem
|
|
ctx.async_copy(
|
|
src_ref=src,
|
|
dst_ref=tmp,
|
|
swizzle=128,
|
|
gmem_transform=mgpu.TileTransform((64, 64)),
|
|
gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])),
|
|
barrier=barrier,
|
|
)
|
|
barrier.wait()
|
|
copy(tmp, dst, swizzle=128)
|
|
x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape)
|
|
tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16)
|
|
y_tiled = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), x, tiled, (tiled, mgpu.TMABarrier()),
|
|
)(x)
|
|
y = y_tiled.swapaxes(1, 2).reshape(padded_shape)
|
|
# y should contain x and zero everywhere else.
|
|
np.testing.assert_array_equal(y[:shape[0], :shape[1]], x)
|
|
y_mut = np.asarray(y).copy()
|
|
y_mut[:shape[0], :shape[1]] = 0
|
|
np.testing.assert_array_equal(y_mut, np.zeros_like(y_mut))
|
|
|
|
@parameterized.product(
|
|
small_dim=(0, 1),
|
|
tiled_layout=(False, True),
|
|
)
|
|
def test_tma_small_tile_store(self, small_dim, tiled_layout):
|
|
if small_dim == 0:
|
|
shape = (4, 128)
|
|
elif small_dim == 1:
|
|
shape = (128, 8)
|
|
else:
|
|
raise ValueError("small_dim must be 0 or 1")
|
|
tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64)
|
|
m, n = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2]))
|
|
def kernel(ctx, dst, tmp):
|
|
vals = iota_tensor(m, n, jnp.float16)
|
|
vals.store_tiled(tmp, swizzle=128)
|
|
ctx.async_copy(
|
|
src_ref=tmp,
|
|
dst_ref=dst,
|
|
swizzle=128,
|
|
gmem_transform=mgpu.TileTransform((64, 64)),
|
|
gmem_slice=(ds(0, m), ds(0, n)),
|
|
)
|
|
ctx.await_async_copy(0)
|
|
tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16)
|
|
out = jax.ShapeDtypeStruct(shape, jnp.float16)
|
|
y = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out, tiled,
|
|
)()
|
|
iota = np.arange(m * n, dtype=jnp.float16).reshape([m, n])
|
|
np.testing.assert_array_equal(y, iota[:shape[0], :shape[1]])
|
|
|
|
  def test_tma_invalid(self):
    def kernel(ctx, src, dst, tmp):
      copy(src, tmp)
      ctx.async_copy(src_ref=tmp, dst_ref=dst)
      ctx.await_async_copy(0)

    def run_kernel(shape):
      x = np.arange(np.prod(shape)).reshape(shape)
      _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)

    with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"):
      run_kernel([1] * 6)

    with self.assertRaisesRegex(
        ValueError, "last dimension to be divisible by 16"
    ):
      run_kernel([23])
|
|
|
|
|
|
class FragmentedArrayTest(TestCase):
|
|
|
|
@parameterized.product(
|
|
op=(
|
|
operator.add,
|
|
operator.mul,
|
|
operator.sub,
|
|
(lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum),
|
|
(lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum),
|
|
),
|
|
dtype=[jnp.float32, jnp.int32, jnp.uint32],
|
|
m=(64, 128),
|
|
n=(8, 16, 32, 64, 80, 128, 256),
|
|
)
|
|
@jtu.ignore_warning(
|
|
message="(invalid value|divide by zero)", category=RuntimeWarning
|
|
)
|
|
def test_binary(self, op, dtype, m=64, n=32):
|
|
if isinstance(op, tuple):
|
|
op, np_op = op
|
|
else:
|
|
np_op = op
|
|
|
|
for scalar_rhs in [None, 2]:
|
|
def kernel(ctx, dst, _):
|
|
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
|
iota = iota_tensor(m, n, dtype)
|
|
rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype)
|
|
op(iota, rhs).store_untiled(dst)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
ref_x = np.arange(m * n, dtype=dtype).reshape(m, n)
|
|
ref_rhs = scalar_rhs or ref_x
|
|
np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs))
|
|
|
|
def test_minimum_np_compatibility(self):
|
|
one = np.ones((128, 128)).astype(np.float32)
|
|
negz = one * -0.
|
|
posz = one * 0.
|
|
nan = one * np.nan
|
|
expectation = (np.minimum(negz, posz) == negz) & (np.minimum(nan, one) != one)
|
|
assert np.all(expectation), expectation
|
|
|
|
def kernel(ctx, dst, _):
|
|
f32 = ir.F32Type.get()
|
|
splat = lambda i: mgpu.FragmentedArray.splat(c(i, f32), (128, 128))
|
|
negz = splat(-0.)
|
|
posz = splat(0.)
|
|
nan = splat(np.nan)
|
|
one = splat(1.)
|
|
res = (negz.min(posz) == negz) & (one.min(nan) != one) & (nan.min(one) != one)
|
|
i8 = ir.IntegerType.get_signless(8)
|
|
res.astype(i8, is_signed=False).store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((128, 128), np.int8)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
np.testing.assert_array_equal(result == 1, expectation)
|
|
|
|
@parameterized.product(
|
|
op=[operator.truediv, operator.floordiv, operator.mod],
|
|
dtype=[jnp.float32, jnp.int32, jnp.uint32],
|
|
)
|
|
def test_division(self, op, dtype, m=64, n=32):
|
|
if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv:
|
|
self.skipTest("Unsupported for integer types")
|
|
if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod:
|
|
self.skipTest("Unsupported for floating types")
|
|
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(m, n, dtype)
|
|
op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
iota = np.arange(m * n, dtype=dtype).reshape(m, n)
|
|
np.testing.assert_allclose(
|
|
result, op(dtype(4.2).item() * iota, iota + 1), atol=2e-7
|
|
)
|
|
|
|
@parameterized.product(
|
|
op=[
|
|
operator.lt,
|
|
operator.le,
|
|
operator.gt,
|
|
operator.ge,
|
|
operator.eq,
|
|
operator.ne,
|
|
],
|
|
dtype=[jnp.float32, jnp.int32, jnp.uint32],
|
|
rhs_is_literal=[False, True]
|
|
)
|
|
  def test_comparison(self, op, dtype, rhs_is_literal, m=64, n=32):
    def kernel(ctx, dst, _):
      iota = iota_tensor(m, n, dtype)
      rhs = 0 if rhs_is_literal else iota + 1
      op(iota, rhs).store_untiled(dst)

    out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool)
    result = mgpu.as_gpu_kernel(
        kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
    )()
    iota = np.arange(m * n, dtype=dtype).reshape(m, n)
    rhs = 0 if rhs_is_literal else iota + 1
    np.testing.assert_array_equal(result, op(iota, rhs))
|
|
|
|
@parameterized.product(tiled_layout=(False, True))
|
|
def test_foreach(self, tiled_layout):
|
|
dtype = jnp.int32
|
|
swizzle = 128
|
|
tile = 64, swizzle // jnp.dtype(dtype).itemsize
|
|
shape = 128, 192
|
|
tiled_shape = mgpu.tile_shape(shape, tile)
|
|
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
|
cst = 9999
|
|
def causal(val, idx):
|
|
row, col = idx
|
|
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
|
|
return arith.select(mask, val, c(cst, mlir_dtype))
|
|
|
|
tiling = mgpu.TileTransform(tile)
|
|
def kernel(ctx, dst, smem):
|
|
x = iota_tensor(shape[0], shape[1], dtype, tiled_layout=tiled_layout)
|
|
x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem)
|
|
mgpu.commit_shared()
|
|
ctx.async_copy(src_ref=smem, dst_ref=dst)
|
|
ctx.await_async_copy(0)
|
|
|
|
iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel,
|
|
(1, 1, 1),
|
|
(128, 1, 1),
|
|
(),
|
|
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
|
|
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
|
|
)()
|
|
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
|
|
np.testing.assert_array_equal(result, expected)
|
|
|
|
@parameterized.product(
|
|
op=[operator.and_, operator.or_, operator.xor],
|
|
dtype=[jnp.uint32],
|
|
)
|
|
def test_bitwise(self, op, dtype, m=64, n=8):
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(m, n, dtype)
|
|
op(iota, iota + 1).store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
iota = np.arange(m * n, dtype=dtype).reshape(m, n)
|
|
np.testing.assert_array_equal(result, op(iota, iota + 1))
|
|
|
|
@parameterized.product(
|
|
ops=(
|
|
(lambda x: -x, jax.lax.neg),
|
|
(lambda x: x + 42, lambda x: x + 42),
|
|
(lambda x: x.tanh(), jax.lax.tanh),
|
|
),
|
|
dtype=[jnp.float32, jnp.int32, jnp.uint32],
|
|
)
|
|
def test_unary(self, ops, dtype, m=64, n=32):
|
|
op, np_op = ops
|
|
if np_op is jax.lax.tanh and jnp.issubdtype(dtype, jnp.integer):
|
|
raise self.skipTest("Tanh not supported for integer types")
|
|
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(m, n, dtype)
|
|
op(iota).store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((m, n), dtype)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
x = np.arange(m * n, dtype=dtype).reshape(m, n)
|
|
np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7)
|
|
|
|
def test_select(self, m=64, n=32):
|
|
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(m, n, jnp.int32)
|
|
(iota < 16).select(iota * 2, iota * 3).store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
x = np.arange(m * n, dtype=jnp.int32).reshape(m, n)
|
|
np.testing.assert_array_equal(result, np.where(x < 16, x * 2, x * 3))
|
|
|
|
@parameterized.product(
|
|
ops=[
|
|
(lambda x: mgpu.FragmentedArray.exp(x), np.exp),
|
|
(lambda x: mgpu.FragmentedArray.sin(x), np.sin),
|
|
(lambda x: mgpu.FragmentedArray.cos(x), np.cos),
|
|
(lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt),
|
|
],
|
|
approx=[False, True],
|
|
)
|
|
@jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning)
|
|
def test_math(self, ops, approx, m=64, n=32):
|
|
op, np_op = ops
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(m, n, jnp.float32)
|
|
op(iota).store_untiled(dst)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
|
|
atol = 5e-3 if approx else 2e-7
|
|
rtol = 4e-6 if approx else 2e-7
|
|
np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol)
|
|
|
|
@parameterized.product(
|
|
dtype=[jnp.float32, jnp.int32],
|
|
m=[128],
|
|
n=[32, 64],
|
|
)
|
|
def test_strided_reduce_sum(self, dtype, m, n):
|
|
def kernel(ctx, src, dst, scratch):
|
|
src = mgpu.FragmentedArray.load_strided(
|
|
src, is_signed=utils.is_signed(dtype)
|
|
)
|
|
acc = src.reduce_sum(scratch).broadcast((m,))
|
|
acc.store_untiled(dst)
|
|
|
|
in_shape = jax.ShapeDtypeStruct((m, n), dtype)
|
|
out_shape = jax.ShapeDtypeStruct((m,), dtype)
|
|
kernel_fn = mgpu.as_gpu_kernel(
|
|
kernel,
|
|
(1, 1, 1),
|
|
(128, 1, 1),
|
|
in_shape,
|
|
out_shape,
|
|
smem_scratch_shape=jax.ShapeDtypeStruct((4,), dtype),
|
|
)
|
|
x = np.arange(m * n, dtype=dtype).reshape(m, n)
|
|
np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum()))
|
|
|
|
@parameterized.product(
|
|
dtype=[jnp.float32, jnp.int32],
|
|
m=[128],
|
|
n=[32, 64],
|
|
)
|
|
def test_splat_reduce_sum(self, dtype, m, n):
|
|
def kernel(ctx, dst, _):
|
|
src = mgpu.FragmentedArray.splat(
|
|
utils.c(1, utils.dtype_to_ir_type(dtype)),
|
|
(m, n),
|
|
is_signed=utils.is_signed(dtype),
|
|
)
|
|
acc = src.reduce_sum().broadcast((m,))
|
|
acc.store_untiled(dst)
|
|
|
|
kernel_fn = mgpu.as_gpu_kernel(
|
|
kernel,
|
|
(1, 1, 1),
|
|
(128, 1, 1),
|
|
in_shape=(),
|
|
out_shape=jax.ShapeDtypeStruct((m,), dtype),
|
|
smem_scratch_shape=(),
|
|
)
|
|
np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), m * n * 1.0))
|
|
|
|
@parameterized.product(
|
|
op=(arith.addf, arith.maximumf),
|
|
m=(64, 128),
|
|
n=(8, 16, 32, 64, 80, 128, 256),
|
|
)
|
|
def test_reduce(self, op, m=64, n=32):
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(m, n, jnp.float32)
|
|
iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
|
|
if op == arith.addf:
|
|
expected = np.broadcast_to(x.sum(axis=1, keepdims=True), x.shape)
|
|
elif op == arith.maximumf:
|
|
expected = np.broadcast_to(x.max(axis=1, keepdims=True), x.shape)
|
|
else:
|
|
raise NotImplementedError(f"Unsupported op: {op}")
|
|
np.testing.assert_array_equal(result, expected)
|
|
|
|
def test_splat_layout(self):
|
|
m, n = 64, 8
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(m, n, jnp.float32)
|
|
cte = c(1, iota.mlir_dtype)
|
|
cte_arr = mgpu.FragmentedArray.splat(cte, ())
|
|
cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n))
|
|
(iota + cte_arr).store_untiled(dst)
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
expected = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1
|
|
np.testing.assert_array_equal(result, expected)
|
|
|
|
def test_splat(self):
|
|
def kernel(ctx, dst, _):
|
|
f32 = ir.F32Type.get()
|
|
v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14))
|
|
t = mgpu.FragmentedArray.splat(
|
|
v, (128,), mgpu.WGMMA_ROW_LAYOUT
|
|
)
|
|
t.broadcast_minor(32).store_untiled(dst)
|
|
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32))
|
|
|
|
def test_splat_binary_ops(self):
|
|
def kernel(ctx, src, dst, _):
|
|
f32 = ir.F32Type.get()
|
|
pi_arr = mgpu.FragmentedArray.load_strided(src)
|
|
assert isinstance(pi_arr.layout, mgpu.WGStridedFragLayout)
|
|
pi_scalar = arith.constant(f32, ir.FloatAttr.get(f32, 3.14))
|
|
pi_splat = mgpu.FragmentedArray.splat(pi_scalar, ())
|
|
assert isinstance(pi_splat.layout, mgpu.WGSplatFragLayout)
|
|
pi_arr_sq = pi_arr * pi_splat.broadcast(pi_arr.shape)
|
|
assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout)
|
|
pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq
|
|
assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout)
|
|
(pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32)
|
|
inp = jnp.ones_like(out_shape) * 3.14
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, ()
|
|
)(inp)
|
|
np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32))
|
|
|
|
@parameterized.product(in_shape=((128, 128), (128, 64), (64, 128)))
|
|
def test_strided_load_store(self, in_shape):
|
|
def kernel(ctx, *args):
|
|
gmem_input, gmem_output, (smem_input, smem_output) = args
|
|
copy(gmem_input, smem_input)
|
|
t = mgpu.FragmentedArray.load_strided(smem_input)
|
|
t.store_untiled(smem_output)
|
|
copy(smem_output, gmem_output)
|
|
|
|
inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out],
|
|
)(inp)
|
|
np.testing.assert_array_equal(inp, result)
|
|
|
|
def test_warp_tree_reduce(self):
|
|
def kernel(ctx, out, *_):
|
|
del ctx
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
tid = gpu.thread_id(gpu.Dimension.x)
|
|
value = arith.index_cast(i32, tid)
|
|
grp = warp_tree_reduce(value, arith.addi, 4)
|
|
memref.store(grp, out, [tid])
|
|
|
|
x = np.arange(128, dtype=jnp.int32)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), x, [],
|
|
)()
|
|
for i in range(0, 128, 4):
|
|
x[i:i + 4] = jnp.sum(x[i:i + 4])
|
|
|
|
np.testing.assert_array_equal(result, x)
|
|
|
|
@parameterized.parameters(2, 4)
|
|
def test_fast_i8_convert(self, reg_length):
|
|
jax_dtype_to = jnp.dtype(jnp.bfloat16)
|
|
jax_dtype_from = jnp.dtype(jnp.int8)
|
|
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
|
|
def kernel(ctx, inp, out, smem):
|
|
del ctx, smem
|
|
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True)
|
|
assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length]
|
|
arr.astype(mlir_dtype_to).store_untiled(out)
|
|
|
|
x = jnp.arange(-128, 128, dtype=jax_dtype_from)
|
|
x = jnp.tile(x, reg_length // 2)
|
|
reference = x.astype(jax_dtype_to)
|
|
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), x, reference, None,
|
|
)(x)
|
|
np.testing.assert_array_equal(result, reference)
|
|
|
|
@parameterized.parameters(
|
|
([64 * 4], "WGMMA_ROW_LAYOUT"),
|
|
([64 * 4, 8 * 2], "WGMMA_LAYOUT"),
|
|
([64 * 4, 8 * 2], "TILED_LAYOUT_WGMMA"),
|
|
)
|
|
def test_to_layout(self, shape, new_layout):
|
|
def kernel(ctx, _):
|
|
# No assertions, we are just checking there are no compile-time errors.
|
|
arr = mgpu.FragmentedArray.splat(c(42.0, ir.F32Type.get()), shape)
|
|
arr.to_layout(getattr(mgpu, new_layout))
|
|
|
|
_ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)()
|
|
|
|
@parameterized.parameters(
|
|
(jnp.float16, jnp.float16), # Noop
|
|
(jnp.int16, jnp.bfloat16),
|
|
(jnp.int16, jnp.float16),
|
|
(jnp.uint16, jnp.float16),
|
|
(jnp.float32, jnp.int32),
|
|
(jnp.float32, jnp.uint32),
|
|
(jnp.uint32, jnp.int32),
|
|
(jnp.int32, jnp.uint32),
|
|
)
|
|
def test_bitcast(self, in_dtype, out_dtype):
|
|
out_ir_type = utils.dtype_to_ir_type(out_dtype)
|
|
in_is_signed = utils.is_signed(in_dtype)
|
|
out_is_signed = utils.is_signed(out_dtype)
|
|
|
|
def kernel(ctx, inp, out, smem):
|
|
del ctx, smem
|
|
arr = mgpu.FragmentedArray.load_strided(inp, is_signed=in_is_signed)
|
|
arr = arr.bitcast(out_ir_type, output_is_signed=out_is_signed)
|
|
arr.store_untiled(out)
|
|
|
|
x = jnp.arange(256, dtype=in_dtype)
|
|
reference = jax.lax.bitcast_convert_type(x, out_dtype)
|
|
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel,
|
|
(1, 1, 1),
|
|
(128, 1, 1),
|
|
x,
|
|
reference,
|
|
None,
|
|
)(x)
|
|
np.testing.assert_array_equal(result, reference)
|
|
|
|
@parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16)
|
|
def test_optimization_barrier(self, dtype):
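    # optimization_barrier should behave as an identity on the values threaded
    # through it, so the kernel must still compute arr + arr * 2 == 3 * x.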
|
|
def kernel(ctx, inp, out, smem):
|
|
del ctx, smem
|
|
arr = mgpu.FragmentedArray.load_strided(inp)
|
|
arr2 = arr * 2
|
|
arr, arr2 = mgpu.optimization_barrier(arr, arr2)
|
|
(arr + arr2).store_untiled(out)
|
|
|
|
x = jnp.arange(256, dtype=dtype)
|
|
|
|
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None)
|
|
np.testing.assert_array_equal(f(x), x * 3)
|
|
|
|
def test_convert_bool_to_u8(self):
|
|
m, n = 128, 128
|
|
def kernel(ctx, dst, _):
|
|
i8 = ir.IntegerType.get_signless(8)
|
|
iota = iota_tensor(m, n, jnp.uint8)
|
|
(iota > 10).astype(i8, is_signed=False).store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
iota = np.arange(m * n, dtype=jnp.uint8).reshape(m, n)
|
|
np.testing.assert_array_equal(result, (iota > 10).astype(jnp.uint8))
|
|
|
|
@parameterized.parameters(
|
|
(jnp.uint8, jnp.uint16, 255),
|
|
(jnp.uint8, jnp.int16, 255),
|
|
(jnp.int8, jnp.uint16, -127),
|
|
(jnp.int8, jnp.int16, -127),
|
|
)
|
|
def test_convert_int_uint(self, from_dtype, to_dtype, value):
|
|
m, n = 1, 128
|
|
def kernel(ctx, dst, _):
|
|
i8 = ir.IntegerType.get_signless(8)
|
|
from_mlir_dtype = utils.dtype_to_ir_type(from_dtype)
|
|
to_mlir_dtype = utils.dtype_to_ir_type(to_dtype)
|
|
from_arr = mgpu.FragmentedArray.splat(
|
|
c(value, from_mlir_dtype),
|
|
(m, n),
|
|
is_signed=utils.is_signed(from_dtype),
|
|
)
|
|
to_arr = from_arr.astype(to_mlir_dtype, is_signed=utils.is_signed(to_dtype))
|
|
to_arr.store_untiled(dst)
|
|
|
|
out_shape = jax.ShapeDtypeStruct((m, n), to_dtype)
|
|
result = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
|
|
)()
|
|
expected = jnp.full((m, n), value, dtype=from_dtype).astype(to_dtype)
|
|
np.testing.assert_array_equal(result, expected)
|
|
|
|
|
|
class ProfilerTest(TestCase):
|
|
|
|
def test_measure_events_explicit(self):
|
|
x = jnp.arange(1024 * 1024)
|
|
_, runtime_ms = profiler.measure(lambda x, y: x + y, mode="events")(x, x)
|
|
self.assertIsInstance(runtime_ms, float)
|
|
|
|
def test_profile(self):
|
|
def kernel(ctx, src, dst, _):
|
|
mgpu.FragmentedArray.load_strided(src).store_untiled(dst)
|
|
x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
|
|
spec = profiler.ProfilerSpec(1024)
|
|
# This is just a smoke test.
|
|
f = jax.jit(mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), x, x, (), prof_spec=spec
|
|
))
|
|
jax.block_until_ready(f(x))
|
|
|
|
def test_multigpu(self):
|
|
if len(jax.devices()) < 2:
|
|
self.skipTest("Need at least 2 devices")
|
|
def kernel(ctx, src, dst, _):
|
|
mgpu.FragmentedArray.load_strided(src).store_untiled(dst)
|
|
x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64)
|
|
f = jax.jit(mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), x, x, ()
|
|
))
|
|
# Make sure we can invoke the same program on different devices.
|
|
for xd in (jax.device_put(x, d) for d in jax.devices()[:2]):
|
|
jax.block_until_ready(f(xd))
|
|
|
|
|
|
class TorchTest(TestCase):
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
try:
|
|
import torch
|
|
except ImportError:
|
|
raise unittest.SkipTest("Test requires PyTorch")
|
|
self.torch = torch
|
|
|
|
def test_basic(self):
|
|
def kernel(ctx, i_gmem, o_gmem, _):
|
|
x = mgpu.FragmentedArray.load_strided(i_gmem)
|
|
(x + x).store_untiled(o_gmem)
|
|
|
|
ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
|
|
x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda')
|
|
f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())
|
|
y = f(x)
|
|
np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
|
|
del y # Make sure the destructor runs successfully.
|
|
|
|
|
|
class LayoutTest(TestCase):
|
|
|
|
@parameterized.named_parameters(
|
|
("f32", jnp.float32, 256),
|
|
("f16", jnp.float16, 256),
|
|
("f16_small", jnp.float16, 128),
|
|
)
|
|
def test_store_untiled_splat(self, jax_dtype, size):
|
|
mlir_dtype = utils.dtype_to_ir_type(jax_dtype)
|
|
def kernel(ctx, out, _):
|
|
del ctx
|
|
arr = mgpu.FragmentedArray.splat(
|
|
c(1.0, mlir_dtype), (size,), is_signed=utils.is_signed(jax_dtype)
|
|
)
|
|
arr.store_untiled(out)
|
|
expected = np.ones((size,), jax_dtype)
|
|
mosaic_ones = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), (), expected, ()
|
|
)()
|
|
np.testing.assert_array_equal(mosaic_ones, expected)
|
|
|
|
@parameterized.product(
|
|
shape=((128, 128), (64, 8), (64, 256)),
|
|
dtype=(jnp.int32, jnp.int16, jnp.int8),
|
|
)
|
|
def test_wgmma_tiled_layout(self, shape, dtype):
|
|
def kernel(ctx, dst, _):
|
|
iota = iota_tensor(*shape, dtype)
|
|
tiled = iota.to_layout(fa._tiled_wgmma_layout(shape))
|
|
# Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1)
|
|
self.assertEqual(
|
|
tiled.registers.shape,
|
|
(shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1),
|
|
)
|
|
self.assertEqual(tiled.shape, shape)
|
|
self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype)
|
|
tiled.store_untiled(dst)
|
|
ty = jax.ShapeDtypeStruct(shape, dtype)
|
|
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ())
|
|
expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
|
|
np.testing.assert_array_equal(f(), expected)
|
|
|
|
@parameterized.product(
|
|
load_tiled=[False, True],
|
|
store_tiled=[False, True],
|
|
dtype=[jnp.int8, jnp.int16, jnp.int32],
|
|
swizzle=[16, 32, 64, 128],
|
|
num_col_tiles=[1, 2, 3],
|
|
row_tiling=[8, 64],
|
|
)
|
|
  def test_copy_tiled(
      self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles, row_tiling
  ):
    if (not load_tiled or not store_tiled) and row_tiling != 64:
      self.skipTest("Old code path does not support this")
|
|
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
|
bw = bytewidth(mlir_dtype)
|
|
col_tiling = swizzle // bw
|
|
if col_tiling % 8:
|
|
self.skipTest("WGMMA layout requires col_tiling % 8 == 0")
|
|
m, n = 128, col_tiling * num_col_tiles
|
|
tiling = (row_tiling, col_tiling)
|
|
tiled_layout = fa._tiled_wgmma_layout((m, n))
|
|
load_layout = tiled_layout if load_tiled else mgpu.TILED_LAYOUT_WGMMA
|
|
store_layout = tiled_layout if store_tiled else mgpu.TILED_LAYOUT_WGMMA
|
|
if (not load_tiled or not store_tiled) and bw == 4 and swizzle == 32:
|
|
self.skipTest("Old code path does not support this")
|
|
def kernel(ctx, in_, out, smems):
|
|
smem_in, smem_out, barrier = smems
|
|
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
|
|
barrier.wait()
|
|
t = mgpu.FragmentedArray.load_tiled(
|
|
smem_in, swizzle=swizzle, is_signed=True, layout=load_layout
|
|
)
|
|
t.to_layout(store_layout).store_tiled(smem_out, swizzle=swizzle)
|
|
mgpu.commit_shared()
|
|
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
|
|
ctx.await_async_copy(0)
|
|
expected = (
|
|
np.arange(m * n, dtype=dtype)
|
|
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
|
|
.transpose(0, 2, 1, 3)
|
|
)
|
|
|
|
prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
|
|
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
|
|
try:
|
|
with jtu.capture_stdout() as get_sass:
|
|
iota = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
|
|
[expected, expected, mgpu.TMABarrier()],
|
|
)(expected)
|
|
finally:
|
|
if prev_dump is not None:
|
|
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
|
|
else:
|
|
del os.environ["MOSAIC_GPU_DUMP_SASS"]
|
|
np.testing.assert_array_equal(iota, expected)
|
|
|
|
# Verify that we don't use too many registers for the transfers.
|
|
# We verify LDS and STS separately, because they might use two different
|
|
# methods of computing offsets and we don't rely on CSE between them.
|
|
expected_regs = swizzle // bytewidth(mlir_dtype) // 8
|
|
# When the bytewidth is smaller than 2 the swizzle pattern changes every 2
|
|
# column tiles, so we only need half the registers.
|
|
if load_tiled and store_tiled: # The old code doesn't optimize properly.
|
|
if bytewidth(mlir_dtype) < 2:
|
|
expected_regs //= 2
|
|
for instr in ("STS", "LDS"):
|
|
with self.subTest(instr + " count"):
|
|
addrs = re.findall(instr + r".* \[(.*)\]", get_sass())
|
|
def get_reg(addr):
|
|
if (pos := addr.find("+")) != -1:
|
|
return addr[:pos]
|
|
return addr
|
|
used_regs = {get_reg(addr) for addr in addrs}
|
|
self.assertLessEqual(len(used_regs), expected_regs)
|
|
|
|
def test_copy_for_upcast(self):
|
|
dtype = jnp.int8
|
|
swizzle = 128
|
|
col_tiling = swizzle // bytewidth(utils.dtype_to_ir_type(dtype))
|
|
m, n = 128, col_tiling * 2
|
|
tiling = (64, col_tiling)
|
|
tiled_layout = fa._tiled_wgmma_layout_for_upcast((m, n))
|
|
def kernel(ctx, in_, out, smems):
|
|
smem_in, smem_out, barrier = smems
|
|
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
|
|
barrier.wait()
|
|
t = mgpu.FragmentedArray.load_tiled(
|
|
smem_in, swizzle=swizzle, is_signed=True, layout=tiled_layout
|
|
)
|
|
t.store_tiled(smem_out, swizzle=swizzle)
|
|
mgpu.commit_shared()
|
|
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
|
|
ctx.await_async_copy(0)
|
|
x = jax.random.randint(
|
|
jax.random.key(42), tile_shape((m, n), tiling), -128, 127, dtype=dtype
|
|
)
|
|
f = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), x, x, [x, x, mgpu.TMABarrier()],
|
|
)
|
|
np.testing.assert_array_equal(f(x), x)
|
|
|
|
@parameterized.product(
|
|
dtype=[jnp.int16], # TODO(apaszke): More dtypes
|
|
# TODO(apaszke): swizzle=64 <- not implemented in transfer_tiled right now
|
|
swizzle=[16, 32, 128],
|
|
)
|
|
def test_transpose_tiled(self, dtype, swizzle):
|
|
mlir_dtype = utils.dtype_to_ir_type(dtype)
|
|
bw = bytewidth(mlir_dtype)
|
|
col_tiling = swizzle // bw
|
|
m, n = 128, 256
|
|
tiling = (8, col_tiling)
|
|
transpose_layout = fa.WGMMA_TRANSPOSED_LAYOUT
|
|
def kernel(ctx, in_, out, smems):
|
|
smem_in, smem_out, barrier = smems
|
|
ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier)
|
|
barrier.wait()
|
|
t = mgpu.FragmentedArray.load_tiled(
|
|
smem_in, swizzle=swizzle, is_signed=True, layout=fa.TILED_LAYOUT_WGMMA
|
|
)
|
|
smem_out_t = memref_transpose(smem_out, (1, 0, 3, 2))
|
|
t.to_layout(transpose_layout).store_tiled(smem_out_t, swizzle=swizzle)
|
|
mgpu.commit_shared()
|
|
ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle)
|
|
ctx.await_async_copy(0)
|
|
x = (
|
|
np.arange(m * n, dtype=dtype)
|
|
.reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1])
|
|
.transpose(0, 2, 1, 3)
|
|
)
|
|
y_ref = (
|
|
np.arange(m * n, dtype=dtype)
|
|
.reshape(m, n)
|
|
.T.reshape(n // tiling[0], tiling[0], m // tiling[1], tiling[1])
|
|
.transpose(0, 2, 1, 3)
|
|
)
|
|
|
|
y = mgpu.as_gpu_kernel(
|
|
kernel, (1, 1, 1), (128, 1, 1), x, y_ref, [x, y_ref, mgpu.TMABarrier()],
|
|
)(x)
|
|
np.testing.assert_array_equal(y, y_ref)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Tile:
  """Defines a Tile transform in a TestCaseInput.

  Note that we cannot simply alias mgpu_dialect.TileTransformAttr.get, because
  we do not have an MLIR context at the point we define the TestCaseInput.
  """

  tiling: tuple[int, ...]

  def attr(self):
    return mgpu_dialect.TileTransformAttr.get(self.tiling)


@dataclasses.dataclass(frozen=True)
class Transpose:
  """Defines a Transpose transform in a TestCaseInput.

  Note that we cannot simply alias mgpu_dialect.TransposeTransformAttr.get,
  because we do not have an MLIR context at the point we define the
  TestCaseInput.
  """

  permutation: tuple[int, ...]

  def attr(self):
    return mgpu_dialect.TransposeTransformAttr.get(self.permutation)


@dataclasses.dataclass(frozen=True)
class Swizzle:
  """Defines a Swizzle transform in a TestCaseInput.

  Note that we cannot simply alias mgpu_dialect.SwizzleTransformAttr.get,
  because we do not have an MLIR context at the point we define the
  TestCaseInput.
  """

  swizzle: mgpu_dialect.SwizzlingMode

  def attr(self):
    return mgpu_dialect.SwizzleTransformAttr.get(self.swizzle)


def memref_with_transforms(
    mem_ref: ir.Value,
    transforms: Sequence[Tile | Transpose | Swizzle],
) -> ir.Value:
  """Casts the memref to one that has a layout with the given transforms."""
  mem_ref_type = ir.MemRefType(mem_ref.type)

  transform_attr = [t.attr() for t in transforms]
  if not transform_attr:
    return mem_ref

  layout = mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, transform_attr)
  memref_new_type = ir.MemRefType.get(
      mem_ref_type.shape,
      mem_ref_type.element_type,
      layout,
      mem_ref_type.memory_space,
  )
  return memref.cast(memref_new_type, mem_ref)
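
# A minimal usage sketch (illustration only; the dialect tests below build the
# transform list per test case): an SMEM ref is typically wrapped before being
# handed to mgpu_dialect.async_load / async_store, e.g.
#
#   smem_ref = memref_with_transforms(
#       smem_ref,
#       [Tile([64, 64]), Swizzle(mgpu_dialect.SwizzlingMode.k128ByteSwizzle)],
#   )
#
# so that the transforms are encoded in the memref layout attribute that the
# lowering later picks up.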
|
|
|
|
|
|
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
|
"""Device tests with lowering from the MLIR dialect and layout inference."""
|
|
|
|
def setUp(self):
|
|
if mgpu_dialect is None:
|
|
raise self.skipTest("Test requires Mosaic GPU dialect")
|
|
super().setUp()
|
|
|
|
def test_pointwise_kernel(self):
|
|
def add(ctx, a, b, result, smem):
|
|
del ctx, smem
|
|
shape = ir.MemRefType(a.type).shape
|
|
elt_type = ir.MemRefType(a.type).element_type
|
|
|
|
zero_index = arith.constant(ir.IndexType.get(), 0)
|
|
zero_vector_indices = [zero_index] * len(shape)
|
|
|
|
# GMEM -> registers
|
|
ab_type = ir.VectorType.get(shape, elt_type)
|
|
a = vector.load(ab_type, a, zero_vector_indices)
|
|
b = vector.load(ab_type, b, zero_vector_indices)
|
|
|
|
# Computation
|
|
add = arith.addf(a, b)
|
|
|
|
# Registers -> GMEM
|
|
vector.store(add, result, zero_vector_indices)
|
|
|
|
dtype = jnp.bfloat16
|
|
shape = (128, 128)
|
|
jax_shape = jax.ShapeDtypeStruct(shape, dtype)
|
|
kernel = mgpu.as_gpu_kernel(
|
|
add,
|
|
grid=(1, 1, 1),
|
|
block=(128, 1, 1),
|
|
in_shape=(jax_shape, jax_shape),
|
|
out_shape=jax_shape,
|
|
smem_scratch_shape=[],
|
|
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
|
)
|
|
|
|
x = self.prng.uniform(-1, 1, shape).astype(dtype)
|
|
y = self.prng.uniform(-1, 1, shape).astype(dtype)
|
|
|
|
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
|
|
|
|
@staticmethod
|
|
def kernel_with_tma_cases(dtype: jnp.dtype):
|
|
@dataclasses.dataclass()
|
|
class TestCaseInput:
|
|
shape: tuple[int, ...]
|
|
shape_sliced: tuple[int, ...] = ()
|
|
slice_indices: tuple[int, ...] = ()
|
|
slice_lengths: tuple[int, ...] = ()
|
|
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
|
|
|
|
def __post_init__(self):
|
|
if not self.shape_sliced:
|
|
self.shape_sliced = self.shape
|
|
if not self.slice_lengths:
|
|
self.slice_lengths = self.shape_sliced
|
|
if not self.slice_indices:
|
|
self.slice_indices = tuple([0] * len(self.slice_lengths))
|
|
|
|
result = []
|
|
for swizzle in mgpu_dialect.SwizzlingMode:
|
|
n = swizzle * 8 // jnp.finfo(dtype).bits
|
|
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
|
|
# We need at least one case with no transforms, as this is handled
|
|
# differently.
|
|
result.append(TestCaseInput(shape=[128, n]))
|
|
result.extend([
|
|
TestCaseInput(
|
|
shape=[128, n],
|
|
transforms=[Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[256, n],
|
|
shape_sliced=[128, n],
|
|
transforms=[Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2, 128, n],
|
|
shape_sliced=[128, n],
|
|
slice_lengths=[-1, 128, n],
|
|
slice_indices=[1, 0, 0],
|
|
transforms=[Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2, 3, 64, n],
|
|
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2, 3, 64, n],
|
|
transforms=[
|
|
Transpose([1, 0, 2, 3]),
|
|
Transpose([1, 0, 2, 3]),
|
|
Swizzle(swizzle),
|
|
],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2, 3, 64, n],
|
|
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[256, n],
|
|
shape_sliced=[128, n],
|
|
transforms=[Tile([64, n]), Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2 * 64, 3 * n],
|
|
transforms=[
|
|
Tile([64, n]),
|
|
Transpose([1, 0, 2, 3]),
|
|
Swizzle(swizzle),
|
|
],
|
|
),
|
|
])
|
|
return result
|
|
|
|
@parameterized.parameters(kernel_with_tma_cases(jnp.bfloat16))
|
|
def test_kernel_with_tma(self, test_case):
|
|
def add(
|
|
ctx: launch_context.LaunchContext,
|
|
in_gmem_ref: ir.Value,
|
|
result_gmem_ref: ir.Value,
|
|
smem: list[ir.Value],
|
|
):
|
|
del ctx
|
|
smem_ref, tma_barrier = smem
|
|
smem_ref = memref_with_transforms(smem_ref, test_case.transforms)
|
|
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
|
|
|
|
elt_type = ir.MemRefType(in_gmem_ref.type).element_type
|
|
memref_bytes = utils.bytewidth(elt_type) * math.prod(
|
|
test_case.shape_sliced
|
|
)
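      # The barrier needs to know how many bytes the async_load below will
      # deposit into SMEM; the wait only completes once they have all arrived.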
|
|
      mgpu_dialect.arrive_expect_tx(
          barrier=dialect_barrier, expect_tx=memref_bytes
      )
|
|
|
|
i32 = ir.IntegerType.get_signless(32)
|
|
slice_indices = [arith.constant(i32, i) for i in test_case.slice_indices]
|
|
|
|
# GMEM -> SMEM
|
|
mgpu_dialect.async_load(
|
|
source=in_gmem_ref,
|
|
destination=smem_ref,
|
|
barrier=dialect_barrier,
|
|
indices=slice_indices,
|
|
slice_lengths=test_case.slice_lengths,
|
|
collective=ir.ArrayAttr.get([]),
|
|
)
|
|
|
|
parities = memref.load(tma_barrier.phases, [])
|
|
parity, _ = tma_barrier.update_parities(parities)
|
|
mgpu_dialect.wait(dialect_barrier, parity)
|
|
|
|
# SMEM -> GMEM
|
|
zero_index = arith.constant(i32, 0)
|
|
mgpu_dialect.async_store(
|
|
source=smem_ref,
|
|
destination=result_gmem_ref,
|
|
indices=[zero_index] * len(test_case.shape_sliced),
|
|
slice_lengths=test_case.shape_sliced,
|
|
)
|
|
nvvm.cp_async_bulk_wait_group(0)
|
|
utils.warpgroup_barrier()
|
|
|
|
dtype = jnp.bfloat16
|
|
|
|
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
|
|
jax_shape_sliced = jax.ShapeDtypeStruct(test_case.shape_sliced, dtype)
|
|
kernel = mgpu.as_gpu_kernel(
|
|
add,
|
|
grid=(1, 1, 1),
|
|
block=(128, 1, 1),
|
|
in_shape=(jax_shape),
|
|
out_shape=jax_shape_sliced,
|
|
smem_scratch_shape=[
|
|
jax_shape_sliced,
|
|
core.TMABarrier(1),
|
|
],
|
|
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
|
)
|
|
|
|
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
|
|
|
|
input_slice = tuple(
|
|
slice(i * abs(l), (i + 1) * abs(l))
|
|
for i, l in zip(test_case.slice_indices, test_case.slice_lengths)
|
|
)
|
|
self.assertArraysEqual(
|
|
jax.jit(kernel)(x),
|
|
(x[input_slice]).reshape(test_case.shape_sliced),
|
|
)
|
|
|
|
@staticmethod
|
|
def pointwise_kernel_with_tma_cases(dtype: jnp.dtype):
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TestCaseInput:
|
|
shape: tuple[int, ...]
|
|
transforms: tuple[Tile | Transpose | Swizzle, ...] = ()
|
|
|
|
result = []
|
|
for swizzle in mgpu_dialect.SwizzlingMode:
|
|
n = swizzle * 8 // jnp.finfo(dtype).bits
|
|
if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle:
|
|
# We need at least one case with no transforms, as this is handled
|
|
# differently.
|
|
result.append(TestCaseInput(shape=[128, n]))
|
|
result.extend([
|
|
TestCaseInput(
|
|
shape=[128, n],
|
|
transforms=[Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2, 3, 64, n],
|
|
transforms=[Transpose([0, 1, 2, 3]), Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2, 3, 64, n],
|
|
transforms=[
|
|
Transpose([1, 0, 2, 3]),
|
|
Transpose([1, 0, 2, 3]),
|
|
Swizzle(swizzle),
|
|
],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2, 3, 64, n],
|
|
transforms=[Transpose([1, 0, 2, 3]), Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[128, n],
|
|
transforms=[Tile([64, n]), Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape=[2 * 64, 3 * n],
|
|
transforms=[
|
|
Tile([64, n]),
|
|
Transpose([1, 0, 2, 3]),
|
|
Swizzle(swizzle),
|
|
],
|
|
),
|
|
])
|
|
return result
|
|
|
|
@parameterized.parameters(pointwise_kernel_with_tma_cases(jnp.bfloat16))
|
|
def test_pointwise_kernel_with_tma(self, test_case):
|
|
def add(
|
|
ctx: launch_context.LaunchContext,
|
|
a_gmem_ref: ir.Value,
|
|
b_gmem_ref: ir.Value,
|
|
result_gmem_ref: ir.Value,
|
|
smem: list[ir.Value],
|
|
):
|
|
del ctx
|
|
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
|
|
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
|
|
|
|
memref_type = ir.MemRefType(a_gmem_ref.type)
|
|
shape = memref_type.shape
|
|
elt_type = memref_type.element_type
|
|
|
|
memref_bytes = utils.bytewidth(elt_type) * math.prod(shape)
|
|
mgpu_dialect.arrive_expect_tx(
|
|
barrier=dialect_barrier, expect_tx=2 * memref_bytes
|
|
)
|
|
|
|
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
|
|
zero_slice_indices = [zero_i32] * memref_type.rank
|
|
|
|
# GMEM -> SMEM
|
|
mgpu_dialect.async_load(
|
|
source=a_gmem_ref,
|
|
destination=memref_with_transforms(
|
|
a_smem_ref, test_case.transforms
|
|
),
|
|
barrier=dialect_barrier,
|
|
indices=zero_slice_indices,
|
|
slice_lengths=shape,
|
|
collective=ir.ArrayAttr.get([]),
|
|
)
|
|
mgpu_dialect.async_load(
|
|
source=b_gmem_ref,
|
|
destination=memref_with_transforms(
|
|
b_smem_ref, test_case.transforms
|
|
),
|
|
barrier=dialect_barrier,
|
|
indices=zero_slice_indices,
|
|
slice_lengths=shape,
|
|
collective=ir.ArrayAttr.get([]),
|
|
)
|
|
|
|
parities = memref.load(tma_barrier.phases, [])
|
|
parity, _ = tma_barrier.update_parities(parities)
|
|
mgpu_dialect.wait(dialect_barrier, parity)
|
|
|
|
zero_index = arith.constant(ir.IndexType.get(), 0)
|
|
zero_vector_indices = [zero_index] * memref_type.rank
|
|
|
|
# SMEM -> registers
|
|
ab_type = ir.VectorType.get(shape, elt_type)
|
|
a = vector.load(ab_type, a_smem_ref, zero_vector_indices)
|
|
b = vector.load(ab_type, b_smem_ref, zero_vector_indices)
|
|
|
|
# Computation
|
|
add = arith.addf(arith.addf(a, b), b)
|
|
|
|
# Registers -> SMEM
|
|
vector.store(add, result_smem_ref, zero_vector_indices)
|
|
|
|
# SMEM -> GMEM
|
|
mgpu_dialect.async_store(
|
|
source=memref_with_transforms(
|
|
result_smem_ref, test_case.transforms
|
|
),
|
|
destination=result_gmem_ref,
|
|
indices=zero_slice_indices,
|
|
slice_lengths=shape,
|
|
)
|
|
nvvm.cp_async_bulk_wait_group(0)
|
|
utils.warpgroup_barrier()
|
|
|
|
dtype = jnp.bfloat16
|
|
|
|
jax_shape = jax.ShapeDtypeStruct(test_case.shape, dtype)
|
|
kernel = mgpu.as_gpu_kernel(
|
|
add,
|
|
grid=(1, 1, 1),
|
|
block=(128, 1, 1),
|
|
in_shape=(jax_shape, jax_shape),
|
|
out_shape=jax_shape,
|
|
smem_scratch_shape=[
|
|
jax_shape,
|
|
jax_shape,
|
|
jax_shape,
|
|
core.TMABarrier(1),
|
|
],
|
|
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
|
)
|
|
|
|
x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
|
|
y = self.prng.uniform(-1, 1, test_case.shape).astype(dtype)
|
|
|
|
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
|
|
|
|
|
|
class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):
|
|
|
|
@staticmethod
|
|
def wgmma_kernel_with_tma_cases(abtype: jnp.dtype):
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TestCaseInput:
|
|
shape_a: tuple[int, ...] = ()
|
|
shape_b: tuple[int, ...] = ()
|
|
shape_res: tuple[int, ...] = ()
|
|
transforms_a: tuple[Tile | Transpose | Swizzle, ...] = ()
|
|
transforms_b: tuple[Tile | Transpose | Swizzle, ...] = ()
|
|
transpose_a: bool = False
|
|
transpose_b: bool = False
|
|
load_a_in_registers: bool = False
|
|
|
|
result = []
|
|
for swizzle in [
|
|
# TODO(dasenov): Add a test for kNoSwizzle, i.e. all swizzling modes.
|
|
mgpu_dialect.SwizzlingMode.k32ByteSwizzle,
|
|
mgpu_dialect.SwizzlingMode.k64ByteSwizzle,
|
|
mgpu_dialect.SwizzlingMode.k128ByteSwizzle,
|
|
]:
|
|
k = swizzle // np.dtype(abtype).itemsize
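      # k elements of `abtype` take up exactly `swizzle` bytes, so k is used as
      # the tile size along the contracting (and N) dimension below.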
|
|
groups_m = 4
|
|
groups_n = 1
|
|
groups_k = 1
|
|
result.extend([
|
|
TestCaseInput(
|
|
shape_a=[groups_m * 64, groups_k * k],
|
|
shape_b=[groups_k * k, groups_n * k],
|
|
shape_res=[groups_m * 64, groups_n * k],
|
|
),
|
|
TestCaseInput(
|
|
shape_a=[groups_m * 64, groups_k * k],
|
|
shape_b=[groups_n * k, groups_k * k],
|
|
shape_res=[groups_m * 64, groups_n * k],
|
|
transpose_b=True,
|
|
),
|
|
TestCaseInput(
|
|
shape_a=[groups_m * 64, groups_k * k],
|
|
shape_b=[groups_k * k, groups_n * k],
|
|
shape_res=[groups_m * 64, groups_n * k],
|
|
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
|
|
transforms_b=[Tile([k, k]), Swizzle(swizzle)],
|
|
),
|
|
TestCaseInput(
|
|
shape_a=[groups_m * 64, groups_k * k],
|
|
shape_b=[groups_k * k, groups_n * k],
|
|
shape_res=[groups_m * 64, groups_n * k],
|
|
transforms_a=[Tile([64, k]), Swizzle(swizzle)],
|
|
load_a_in_registers=True,
|
|
),
|
|
])
|
|
# The below only works for 128-byte swizzling. Regardless of transposing,
|
|
# TMA needs the size of the last dimension to be compatible with the
|
|
# swizzle.
|
|
if swizzle == mgpu_dialect.SwizzlingMode.k128ByteSwizzle:
|
|
result.append(
|
|
TestCaseInput(
|
|
shape_a=[groups_k * k, groups_m * 64],
|
|
shape_b=[groups_k * k, groups_n * k],
|
|
shape_res=[groups_m * 64, groups_n * k],
|
|
transpose_a=True,
|
|
)
|
|
)
|
|
return result
|
|
|
|
@parameterized.parameters(wgmma_kernel_with_tma_cases(jnp.bfloat16))
|
|
def test_wgmma_kernel_with_tma(self, test_case):
|
|
|
|
def matmul(
|
|
ctx: launch_context.LaunchContext,
|
|
a_gmem_ref: ir.Value,
|
|
b_gmem_ref: ir.Value,
|
|
result_gmem_ref: ir.Value,
|
|
smem: list[ir.Value],
|
|
):
|
|
del ctx
|
|
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
|
|
a_smem_ref = memref_with_transforms(a_smem_ref, test_case.transforms_a)
|
|
b_smem_ref = memref_with_transforms(b_smem_ref, test_case.transforms_b)
|
|
dialect_barrier = tma_barrier.as_dialect_barrier_memref()
|
|
|
|
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
|
|
bytes_a = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_a)
|
|
bytes_b = utils.bytewidth(ab_elt_type) * math.prod(test_case.shape_b)
|
|
|
|
mgpu_dialect.arrive_expect_tx(
|
|
barrier=dialect_barrier,
|
|
expect_tx=bytes_a + bytes_b,
|
|
)
|
|
|
|
zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
|
|
# GMEM -> SMEM
|
|
mgpu_dialect.async_load(
|
|
source=a_gmem_ref,
|
|
destination=a_smem_ref,
|
|
barrier=dialect_barrier,
|
|
indices=[zero_i32] * len(test_case.shape_a),
|
|
slice_lengths=test_case.shape_a,
|
|
collective=ir.ArrayAttr.get([]),
|
|
)
|
|
mgpu_dialect.async_load(
|
|
source=b_gmem_ref,
|
|
destination=b_smem_ref,
|
|
barrier=dialect_barrier,
|
|
indices=[zero_i32] * len(test_case.shape_b),
|
|
slice_lengths=test_case.shape_b,
|
|
collective=ir.ArrayAttr.get([]),
|
|
)
|
|
|
|
parities = memref.load(tma_barrier.phases, [])
|
|
parity, _ = tma_barrier.update_parities(parities)
|
|
mgpu_dialect.wait(dialect_barrier, parity)
|
|
|
|
# SMEM -> Registers
|
|
a_operand = a_smem_ref
|
|
zero_index = arith.constant(ir.IndexType.get(), 0)
|
|
if test_case.load_a_in_registers:
|
|
a_vector_type = ir.VectorType.get(test_case.shape_a, ab_elt_type)
|
|
zero_vector_indices = [zero_index] * len(test_case.shape_a)
|
|
a_operand = vector.load(a_vector_type, a_smem_ref, zero_vector_indices)
|
|
|
|
# Computation
|
|
shape_result = ir.MemRefType(result_gmem_ref.type).shape
|
|
result_elt_type = ir.MemRefType(result_gmem_ref.type).element_type
|
|
zero_acc = arith.constant(
|
|
result_elt_type, ir.FloatAttr.get(result_elt_type, 0.0)
|
|
)
|
|
accumulator = vector.splat(
|
|
ir.VectorType.get(shape_result, result_elt_type), zero_acc
|
|
)
|
|
result = mgpu_dialect.wgmma(
|
|
accumulator,
|
|
a_operand,
|
|
b_smem_ref,
|
|
transpose_a=test_case.transpose_a,
|
|
transpose_b=test_case.transpose_b,
|
|
)
|
|
|
|
nvvm.wgmma_commit_group_sync_aligned()
|
|
nvvm.wgmma_wait_group_sync_aligned(0)
|
|
|
|
# Registers -> SMEM
|
|
vector.store(result, result_smem_ref, [zero_index] * len(shape_result))
|
|
|
|
# SMEM -> GMEM
|
|
mgpu_dialect.async_store(
|
|
source=result_smem_ref,
|
|
destination=result_gmem_ref,
|
|
indices=[zero_i32, zero_i32],
|
|
slice_lengths=shape_result,
|
|
)
|
|
nvvm.cp_async_bulk_wait_group(0)
|
|
|
|
abtype = jnp.bfloat16
|
|
acctype = jnp.float32
|
|
a_jax_shape = jax.ShapeDtypeStruct(test_case.shape_a, abtype)
|
|
b_jax_shape = jax.ShapeDtypeStruct(test_case.shape_b, abtype)
|
|
result_jax_shape = jax.ShapeDtypeStruct(test_case.shape_res, acctype)
|
|
kernel = mgpu.as_gpu_kernel(
|
|
matmul,
|
|
grid=(1, 1, 1),
|
|
block=(128, 1, 1),
|
|
in_shape=(a_jax_shape, b_jax_shape),
|
|
out_shape=result_jax_shape,
|
|
smem_scratch_shape=[
|
|
a_jax_shape,
|
|
b_jax_shape,
|
|
result_jax_shape,
|
|
core.TMABarrier(1),
|
|
],
|
|
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
|
|
)
|
|
|
|
x = self.prng.uniform(-1, 1, test_case.shape_a).astype(abtype)
|
|
y = self.prng.uniform(-1, 1, test_case.shape_b).astype(abtype)
|
|
|
|
transpose = lambda x, t: x.T if t else x
|
|
self.assertArraysAllClose(
|
|
jax.jit(kernel)(x, y),
|
|
np.matmul(
|
|
transpose(x.reshape(test_case.shape_a), test_case.transpose_a),
|
|
transpose(y.reshape(test_case.shape_b), test_case.transpose_b),
|
|
),
|
|
atol=1e-5,
|
|
rtol=1e-5,
|
|
)
|
|
|
|
|
|
class UtilsTest(TestCase):
  @parameterized.parameters(
      (1,),
      (-1,),
      (slice(2), slice(3),),
      (slice(1), slice(1, 3)),
      (slice(-2, 0),),
      (slice(-2, -1),),
      *([(utils.DynamicSlice(0, 2),)] if HAS_MOSAIC_GPU else []),
  )
  def test_parse_indices(self, *indices):
    # We are simply making sure this does not raise.
    _, _, _ = utils.parse_indices(indices, (2, 3, 4))

  @parameterized.parameters(
      (42,),
      (-42,),
      (slice(42),),
      (slice(0, 42),),
      (slice(-42, 0),),
      (slice(-4, -42),),
      *([(utils.DynamicSlice(0, 4),)] if HAS_MOSAIC_GPU else []),
  )
  def test_parse_indices_oob(self, indices):
    with self.assertRaisesRegex(IndexError, "out of bounds"):
      utils.parse_indices(indices, (2, 3, 4))
|
|
|
|
|
|
class SerializationTest(absltest.TestCase):

  def test_pass_is_registered(self):
    ctx = mlir.make_ir_context()
    ctx.allow_unregistered_dialects = True
    with ir.Location.unknown(ctx):
      module = ir.Module.create()
      pipeline = passmanager.PassManager.parse(
          "builtin.module(mosaic_gpu-serde{serialize=true})",
          ctx,
      )
      pipeline.run(module.operation)


if __name__ == "__main__":
  absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader())