[Mosaic GPU] Refactor the Blackwell matmul example and make it runnable

The previous implementation depended on LLVM intrinsics that have not been submitted
upstream yet. This replaces them with inline PTX (as far as I can tell there's no
downside to that), encapsulated into convenience functions.
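
The pattern is a thin llvm.inline_asm wrapper per PTX instruction. A minimal
excerpt of the new tcgen05 helpers (the full set is in the new file below):

from jaxlib.mlir import ir
from jaxlib.mlir.dialects import llvm

def tmem_alloc(tmem_addr, ncols):
  # Emit the PTX verbatim; has_side_effects keeps LLVM from dropping it.
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [tmem_addr, ncols],  # mapped to $0 and $1
      "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], $1;",
      "r,r",  # both operands in 32-bit registers
      has_side_effects=True,
  )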

PiperOrigin-RevId: 723498248
Adam Paszke 2025-02-05 07:10:20 -08:00 committed by jax authors
parent e7a4f89343
commit b79ab01ee7
4 changed files with 306 additions and 228 deletions


@@ -85,6 +85,8 @@ from .utils import (
warpgroup_idx as warpgroup_idx,
when as when,
)
# The import below shadows the module, so we need to rename it.
from . import wgmma as _wgmma # noqa: F401
from .wgmma import (
WGMMAAccumulator as WGMMAAccumulator,
wgmma as wgmma,


@@ -12,256 +12,197 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import numpy as np
"""Matmul kernel for Blackwell."""
import jax
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import gpu
from jax._src.lib.mlir.dialects import llvm
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import vector
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import c, ds, utils
from jax.experimental.mosaic.gpu import tcgen05
import jax.numpy as jnp
import jax.random as jr
from jax import ShapeDtypeStruct as SDS
from jax._src.interpreters import mlir
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import c, utils, create_descriptor, ds
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.dialects import vector
from jaxlib.mlir.dialects import scf
import numpy as np
# TODO(andportnoy) refactor into a single function for Hopper/Blackwell?
def create_smem_descriptor(
memref_arg,
leading_byte_offset: int,
stride_byte_offset: int,
swizzle: int | None,
memory_space: int | None = None,
BLACKWELL_MMA_FP16_K = 16
TMA_WARP = 1
MMA_WARP = 0
def bytecount(shape, dtype):
return int(np.prod(shape) * dtype.dtype.itemsize)
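# For example: bytecount((128, 64), jnp.float16) == 128 * 64 * 2 == 16384 bytes,
# which is what the TMA warp posts via arrive_expect_tx below.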
def build_kernel(
m, n, k,
tile_m: int = 128,
tile_n: int = 128,
):
i64 = ir.IntegerType.get_signless(64)
hopper_desc = create_descriptor(
memref_arg, leading_byte_offset, stride_byte_offset, swizzle, memory_space=3)
blackwell_bit = arith.shli(c(1,i64), c(46,i64))
blackwell_desc = llvm.or_(hopper_desc, blackwell_bit)
return blackwell_desc
def smem_descriptor_increment_address(desc, nbytes):
i64 = ir.IntegerType.get_signless(64)
return arith.addi(desc, arith.shrui(c(nbytes,i64), c(4,i64)))
def create_blackwell_fp16_mma_descriptor(
m: int,
n: int,
dtype,
transpose_b: bool = False,
):
i32 = ir.IntegerType.get_signless(32)
desc = c(0,i32)
def fieldset(val, bit):
field = arith.shli(c(val,i32), c(bit, i32))
nonlocal desc
desc = arith.ori(desc, field)
# encoding is dependent on .kind::<foo>, below is for .kind::f16
# 1: 0 - sparsity selector if sparsity is enabled
# 2 - sparsity: dense = 0, sparse = 1
# 3 - saturate for integer types, irrelevant for fp
# 5: 4 - output dtype
fieldset(1, 4)
# 6 - reserved
# 9: 7 - A dtype: f16 = 0, b16 = 1
if dtype == jnp.bfloat16:
fieldset(1, 7)
# 12:10 - B dtype: f16 = 0, b16 = 1
if dtype == jnp.bfloat16:
fieldset(1, 10)
# 13 - negate A
# 14 - negate B
# 15 - transpose A
# 16 - transpose B
if transpose_b:
fieldset(1, 16)
# 22:17 - N >> 3
fieldset(n >> 3, 17) # TODO validate field width
# 23 - reserved
# 28:24 - M >> 4
fieldset(m >> 4, 24)
# 29 - reserved
# 31:30 - max shift under .ws: irrelevant here
return desc
WARPSIZE = 32
m_tile = 128
n_tile = 128
k_tile = 64
m = 16*m_tile
n = 16*n_tile
k = 16*k_tile
# K = 16 is inherent to Blackwell MMA half precision instructions
blackwell_mma_fp16_k = 16
ashape = (m, k) # k-major (row major)
ashape_tile = (m_tile, k_tile)
bshape = (n, k) # k-major (column major)
bshape_tile = (n_tile, k_tile)
dshape = (m, n) # n-major (row major)
in_dtype = jnp.bfloat16
ddtype = jnp.float32
grid = (n//n_tile, m//m_tile, 1)
block = (8*WARPSIZE, 1, 1)
def kernel(ctx, a, b, d, smem):
i1 = ir.IntegerType.get_signless(1)
i32 = ir.IntegerType.get_signless(32)
f32 = ir.F32Type.get()
index = ir.IndexType.get()
ptr = ir.Type.parse("!llvm.ptr")
ptr6 = ir.Type.parse("!llvm.ptr<6>")
tidx = gpu.thread_id(gpu.Dimension.x)
warpid = arith.shrui(tidx, c(5,index))
tma_warp = 0
mma_warp = 1
# need a fully aligned warpgroup, here it's warps 4-7
ldtm_warp_range = (4, 8)
ptr6 = ir.Type.parse("!llvm.ptr<6>") # TMEM
@contextlib.contextmanager
def only_warp(i):
is_warp_i = arith.cmpi(arith.CmpIPredicate.eq, warpid, c(i,index))
with ir.InsertionPoint(scf.IfOp(is_warp_i).then_block):
yield
scf.yield_([])
swizzle = 128
tile_k = 64 # TODO(apaszke): I think we need to tile TMA to change this.
in_dtype = jnp.float16
k_loop_iter = k // tile_k
@contextlib.contextmanager
def only_warp_range(a, b):
gea = arith.cmpi(arith.CmpIPredicate.uge, warpid, c(a,index))
ltb = arith.cmpi(arith.CmpIPredicate.ult, warpid, c(b,index))
predicate = arith.andi(gea, ltb)
with ir.InsertionPoint(scf.IfOp(predicate).then_block):
yield
scf.yield_([])
if m % tile_m != 0:
raise ValueError(f"{m=} must be divisible by {tile_m=}")
if n % tile_n != 0:
raise ValueError(f"{n=} must be divisible by {tile_n=}")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
@contextlib.contextmanager
def single_warp_thread():
elected = nvvm.elect_sync(i1)
with ir.InsertionPoint(scf.IfOp(elected).then_block):
yield
scf.yield_([])
def kernel(ctx, a, b, d, smem):
# TODO(apaszke): Use more SMEM slots to avoid oversynchronizing warps.
a_smem, b_smem, barriers, tmem_addr = smem
(ab_full_barrier, ab_empty_barrier, mma_done_barrier) = barriers
a_shared, b_shared, (ab_full_barrier, ab_empty_barrier, mma_done_barrier), tmem_addr = smem
thread_idx = mgpu.thread_idx()
warp_idx = mgpu.warp_idx(sync=True)
warp_leader = nvvm.elect_sync(i1)
k_loop_iter = k//k_tile
is_warp = lambda i: arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32))
def bytecount(shape, dtype):
return int(np.prod(shape) * dtype.dtype.itemsize)
with mgpu.when(arith.andi(is_warp(TMA_WARP), warp_leader)):
m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m,index))
n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n,index))
@mgpu.fori(c(k_loop_iter, index), None)
def _tma_body(ki, _):
# TODO(apaszke): Use a predicate instead of a conditional.
with mgpu.when(arith.cmpi(arith.CmpIPredicate.ugt, ki, c(0, index))):
ab_empty_barrier.wait()
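# Post how many TMA bytes this step will deliver; the two async copies
# below count their bytes toward the same barrier as they land.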
ab_full_barrier.arrive_expect_tx(
bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
)
k_start = arith.muli(ki, c(tile_k, index))
common_args = dict(
swizzle=swizzle, barrier=ab_full_barrier, arrive=False, uniform=False,
)
ctx.async_copy(
src_ref=a,
dst_ref=a_smem,
gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)),
**common_args,
)
ctx.async_copy(
src_ref=b,
dst_ref=b_smem,
gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)),
**common_args,
)
txcount = bytecount(ashape_tile, in_dtype) + bytecount(bshape_tile, in_dtype)
with only_warp(tma_warp), single_warp_thread():
# XXX TODO move loop iteration into MLIR, otherwise we force unroll
for i in range(k_loop_iter):
if i > 0:
ab_empty_barrier.wait()
ab_full_barrier.arrive_expect_tx(txcount)
m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(m_tile,index))
n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(n_tile,index))
k_start = i*k_tile
ctx.async_copy(
src_ref=a,
dst_ref=a_shared,
gmem_slice=(ds(m_start, m_tile), ds(k_start,k_tile)),
swizzle=128,
barrier=ab_full_barrier,
arrive=False,
uniform=False,
)
ctx.async_copy(
src_ref=b,
dst_ref=b_shared,
gmem_slice=(ds(n_start, n_tile), ds(k_start,k_tile)),
swizzle=128,
barrier=ab_full_barrier,
arrive=False,
uniform=False,
)
with mgpu.when(is_warp(MMA_WARP)):
ncols = c(b_smem.type.shape[0], i32)
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
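# tcgen05.alloc returns the TMEM address by writing it to the given SMEM
# location, so we allocate through a pointer to the tmem_addr scratch buffer.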
tcgen05.tmem_alloc(tmem_addr_addr, ncols)
tcgen05.tmem_relinquish_alloc_permit()
with mgpu.when(warp_leader):
tmem_addr_value = llvm.load(ptr6, tmem_addr_addr)
idesc = tcgen05.create_instr_descriptor(
m=tile_m, n=tile_n, acc_dtype=jnp.float32, input_dtype=in_dtype
)
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
def _mma_body(ki, accumulate):
adesc = tcgen05.create_smem_descriptor(
a_smem, leading_byte_offset=16, stride_byte_offset=1024, swizzle=swizzle)
bdesc = tcgen05.create_smem_descriptor(
b_smem, leading_byte_offset=16, stride_byte_offset=1024, swizzle=swizzle)
ab_full_barrier.wait()
with only_warp(mma_warp):
ncols = c(b_shared.type.shape[0], i32)
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
nvvm.tcgen05_alloc(tmem_addr_addr, ncols)
nvvm.tcgen05_relinquish_alloc_permit()
accumulate = 0
with single_warp_thread():
tmem_addr_value = llvm.load(ptr6, tmem_addr_addr)
idesc = create_blackwell_fp16_mma_descriptor(m_tile, n_tile, in_dtype)
for i in range(k_loop_iter):
adesc = create_smem_descriptor(
a_shared, leading_byte_offset=16, stride_byte_offset=1024, swizzle=128)
bdesc = create_smem_descriptor(
b_shared, leading_byte_offset=16, stride_byte_offset=1024, swizzle=128)
ab_full_barrier.wait()
for _ in range(4):
nvvm.tcgen05_mma("f16", "cta_1", tmem_addr_value, adesc, bdesc, idesc, enable_input_d=c(accumulate,i1))
accumulate = 1
adesc = smem_descriptor_increment_address(adesc, blackwell_mma_fp16_k*2)
bdesc = smem_descriptor_increment_address(bdesc, blackwell_mma_fp16_k*2)
last_iter = i == k_loop_iter-1
barrier = mma_done_barrier if last_iter else ab_empty_barrier
nvvm.tcgen05_commit_arrive(barrier.get_ptr())
# TODO(apaszke): Abstract this into a function.
assert tile_k % BLACKWELL_MMA_FP16_K == 0
def smem_descriptor_increment_address(desc, nbytes):
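# The address field in the low bits of an SMEM descriptor is encoded in
# units of 16 bytes, hence the shift by 4.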
i64 = ir.IntegerType.get_signless(64)
return arith.addi(desc, arith.shrui(c(nbytes,i64), c(4,i64)))
for _ in range(tile_k // BLACKWELL_MMA_FP16_K):
tcgen05.mma("f16", 1, tmem_addr_value, adesc, bdesc, idesc, enable_input_d=accumulate)
accumulate = arith.constant(i1, 1)
adesc = smem_descriptor_increment_address(
adesc, BLACKWELL_MMA_FP16_K * 2
)
bdesc = smem_descriptor_increment_address(
bdesc, BLACKWELL_MMA_FP16_K * 2
)
with only_warp_range(*ldtm_warp_range), ctx.named_region("LDTM"):
is_last_iter = arith.cmpi(
arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
)
barrier_ptr = arith.select(
is_last_iter, mma_done_barrier.get_ptr(), ab_empty_barrier.get_ptr()
)
tcgen05.commit_arrive(barrier_ptr)
return accumulate
# TODO(apaszke): Should we have a warpgroup that's dedicated to this store?
gpu.barrier()
# TODO(apaszke): This has a very slow GMEM access pattern.
mma_done_barrier.wait()
tmem_ptr = llvm.inttoptr(ptr6, memref.load(tmem_addr, [c(0,index)]))
# TODO automate type creation
vector_i32 = nvvm.tcgen05_ld(ir.VectorType.get([n_tile], i32), shape="shape_32x32b", num=n_tile, tmem_addr=tmem_ptr)
vector_i32 = tcgen05.tmem_load(num=tile_n, tmem_addr=tmem_ptr)
vector_f32 = vector.bitcast(ir.VectorType.get(vector_i32.type.shape, f32), vector_i32)
wg_tidx = arith.andi(tidx, c((1 << 7)-1, index)) # tidx within warpgroup
wg_tidx = arith.remui(
arith.index_castui(index, thread_idx), c(utils.WARPGROUP_SIZE, index)
)
row = arith.addi(
arith.muli(
gpu.block_id(gpu.Dimension.y),
c(m_tile,index)),
wg_tidx)
column = arith.muli(
gpu.block_id(gpu.Dimension.x),
c(n_tile,index))
arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m, index)), wg_tidx
)
column = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n, index))
vector.store(vector_f32, d, [row, column])
if __name__ == '__main__':
ka, kb = jr.split(jr.key(0), 2)
a = jr.uniform(key=ka, shape=ashape, dtype=in_dtype)
b = jr.uniform(key=kb, shape=bshape, dtype=in_dtype)
asds_tile = SDS(ashape_tile, a.dtype)
bsds_tile = SDS(bshape_tile, b.dtype)
dsds = SDS(dshape, ddtype)
tmem_addr = SDS((1,), np.uint32)
smem = (asds_tile, bsds_tile, tuple(mgpu.Barrier(arrival_count=1) for _ in range(3)), tmem_addr)
with mlir.make_ir_context(), ir.Location.unknown():
f = mgpu.as_gpu_kernel(
kernel,
grid,
block,
(SDS(a.shape, a.dtype), SDS(b.shape, b.dtype)),
dsds,
smem
)
y = f(a, b)
@jax.jit
def ref_f(x, y):
return jax.lax.dot_general(
x,
y,
dimension_numbers=(((1,), (1,)), ((), ())),
preferred_element_type=jnp.float32,
).astype(jnp.float32)
ref = ref_f(a, b)
np.testing.assert_allclose(
y.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
smem = (
jax.ShapeDtypeStruct((tile_m, tile_k), jnp.float16),
jax.ShapeDtypeStruct((tile_n, tile_k), jnp.float16),
[mgpu.Barrier(arrival_count=1)] * 3,
jax.ShapeDtypeStruct((1,), np.uint32), # TMEM address
)
return mgpu.as_gpu_kernel(
kernel,
(n // tile_n, m // tile_m, 1),
(128, 1, 1),
(
jax.ShapeDtypeStruct((m, k), jnp.float16),
jax.ShapeDtypeStruct((n, k), jnp.float16),
),
jax.ShapeDtypeStruct((m, n), jnp.float32),
smem,
)
def main(unused_argv):
m_tile = 128
n_tile = 128
k_tile = 64
m = 16*m_tile
n = 16*n_tile
k = 16*k_tile
ka, kb = jr.split(jr.key(0), 2)
a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16)
with mlir.make_ir_context(), ir.Location.unknown():
f = build_kernel(m, n, k, tile_m=m_tile, tile_n=n_tile)
y = f(a, b).block_until_ready()
ref = np.asarray(a) @ np.asarray(b).T
np.testing.assert_allclose(y, ref, atol=1e-3, rtol=1e-3)
print("OK!")
if __name__ == "__main__":
from absl import app
import jax
jax.config.config_with_absl()
app.run(main)


@@ -0,0 +1,133 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jax._src import dtypes
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
import numpy as np
from . import _wgmma
def create_smem_descriptor(
memref_arg,
leading_byte_offset: int,
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
):
blackwell_bit = 1 << 46
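# Setting bit 46 distinguishes the Blackwell (tcgen05) descriptor encoding
# from the Hopper (wgmma) one that _wgmma.create_descriptor produces.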
return _wgmma.create_descriptor(
memref_arg,
leading_byte_offset,
stride_byte_offset,
swizzle,
memory_space=3,
const_init=blackwell_bit,
)
def create_instr_descriptor(
m: int,
n: int,
acc_dtype,
input_dtype,
transpose_a: bool = False,
transpose_b: bool = False,
):
if input_dtype not in {np.float16, dtypes.bfloat16}:
raise NotImplementedError("Only float16 and bfloat16 inputs supported")
if acc_dtype not in {np.float32, np.float16}:
raise NotImplementedError("Only float32 and float16 accumulators supported")
desc = 0
# We ignore sparsity in bits 0-3
desc |= (acc_dtype == np.float32) << 4 # D dtype, bits 4-5
# Bit 6 is reserved
desc |= (input_dtype == dtypes.bfloat16) << 7 # A dtype, bits 7-9
desc |= (input_dtype == dtypes.bfloat16) << 10 # B dtype, bits 10-12
# We ignore negate bits 13-14
desc |= transpose_a << 15 # Transpose A
desc |= transpose_b << 16 # Transpose B
if n % 8 or n > 256:
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
desc |= (n >> 3) << 17 # N, bits 17-22
# Bit 23 is reserved
if m % 16 or m > 256:
raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
desc |= (m >> 4) << 24 # M >> 4, bits 24-28
# Bit 29 is reserved
# We ignore max shift under .ws, bits 30-31
return arith.constant(ir.IntegerType.get_signless(32), desc) # type: ignore
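For concreteness, the packing above evaluated in plain Python for the shapes the
example uses (m=128, n=128, f16 inputs, f32 accumulator); a sanity check of the
field layout, not part of the diff:

desc = 0
desc |= 1 << 4             # D dtype: f32 accumulator (bits 4-5)
desc |= 0 << 7             # A dtype: f16 (bits 7-9)
desc |= 0 << 10            # B dtype: f16 (bits 10-12)
desc |= (128 >> 3) << 17   # N field: 128 >> 3 == 16 (bits 17-22)
desc |= (128 >> 4) << 24   # M field: 128 >> 4 == 8 (bits 24-28)
assert desc == 0x08200010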
def mma(dtype, num_cta, d_tmem, adesc, bdesc, idesc, enable_input_d):
if not (1 <= num_cta <= 2):
raise ValueError(f"num_cta must be 1 or 2, got: {num_cta}")
return llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[d_tmem, adesc, bdesc, idesc, enable_input_d],
f"tcgen05.mma.cta_group::1.kind::{dtype} [$0], $1, $2, $3, $4;",
"r,l,l,r,b",
has_side_effects=True,
)
def commit_arrive(barrier):
return llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[barrier],
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$0];",
"l",
has_side_effects=True
)
def tmem_alloc(tmem_addr, ncols):
return llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[tmem_addr, ncols],
"tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], $1;",
"r,r",
has_side_effects=True,
)
def tmem_relinquish_alloc_permit():
return llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[],
"tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;",
"",
has_side_effects=True,
)
def tmem_load(tmem_addr, num):
if num.bit_count() != 1 or num > 128:
raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
i32 = ir.IntegerType.get_signless(32)
out_regs = ",".join("$" + str(i) for i in range(num))
regs = llvm.inline_asm(
ir.Type.parse(
"!llvm.struct<(" + ",".join("i32" for _ in range(num)) + ")>"
),
[tmem_addr],
f"tcgen05.ld.sync.aligned.32x32b.x{num}.b32 {{{out_regs}}}, [${num}];",
"=r," * num + "r",
has_side_effects=True,
)
out_ty = ir.VectorType.get([num], i32)
out_vec = llvm.mlir_undef(out_ty)
for i in range(num):
out_vec = llvm.insertelement(
out_vec, llvm.extractvalue(i32, regs, [i]), arith.constant(i32, i)
)
return out_vec
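
For reference, the example kernel consumes this load as follows (adapted from the
matmul example above, so the local names are the kernel's):

tmem_ptr = llvm.inttoptr(ptr6, memref.load(tmem_addr, [c(0, index)]))
vector_i32 = tcgen05.tmem_load(num=tile_n, tmem_addr=tmem_ptr)
vector_f32 = vector.bitcast(ir.VectorType.get(vector_i32.type.shape, f32), vector_i32)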


@@ -100,6 +100,7 @@ def create_descriptor(
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
memory_space: int | None = None,
const_init: int = 0,
):
i64 = ir.IntegerType.get_signless(64)
ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space))
@@ -118,7 +119,8 @@ def create_descriptor(
)
# We ignore the offset
desc_const = (
(wgmma_encode(leading_byte_offset) << 16)
const_init
| (wgmma_encode(leading_byte_offset) << 16)
| (wgmma_encode(stride_byte_offset) << 32)
)
desc = llvm.or_(