mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic GPU] Refactor the Blackwell matmul example and make it runnable
The previous impelmentation depends on LLVM intrinsics that have not been submitted yet. This replaces them with inline PTX (as far as I can tell there's no downside to that) that's encapsulated into convenience functions. PiperOrigin-RevId: 723498248
This commit is contained in:
parent
e7a4f89343
commit
b79ab01ee7
@ -85,6 +85,8 @@ from .utils import (
|
||||
warpgroup_idx as warpgroup_idx,
|
||||
when as when,
|
||||
)
|
||||
# The import below shadows the module, so we need to rename it.
|
||||
from . import wgmma as _wgmma # noqa: F401
|
||||
from .wgmma import (
|
||||
WGMMAAccumulator as WGMMAAccumulator,
|
||||
wgmma as wgmma,
|
||||
|
@ -12,256 +12,197 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
"""Matmul kernel for Blackwell."""
|
||||
|
||||
import jax
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import gpu
|
||||
from jax._src.lib.mlir.dialects import llvm
|
||||
from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import nvvm
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
from jax.experimental.mosaic import gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import c, ds, utils
|
||||
from jax.experimental.mosaic.gpu import tcgen05
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jr
|
||||
from jax import ShapeDtypeStruct as SDS
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.experimental.mosaic import gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import c, utils, create_descriptor, ds
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import gpu
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
from jaxlib.mlir.dialects import memref
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import vector
|
||||
from jaxlib.mlir.dialects import scf
|
||||
import numpy as np
|
||||
|
||||
# TODO(andportnoy) refactor into a single function for Hopper/Blackwell?
|
||||
def create_smem_descriptor(
|
||||
memref_arg,
|
||||
leading_byte_offset: int,
|
||||
stride_byte_offset: int,
|
||||
swizzle: int | None,
|
||||
memory_space: int | None = None,
|
||||
|
||||
BLACKWELL_MMA_FP16_K = 16
|
||||
TMA_WARP = 1
|
||||
MMA_WARP = 0
|
||||
|
||||
|
||||
def bytecount(shape, dtype):
|
||||
return int(np.prod(shape) * dtype.dtype.itemsize)
|
||||
|
||||
|
||||
def build_kernel(
|
||||
m, n, k,
|
||||
tile_m: int = 128,
|
||||
tile_n: int = 128,
|
||||
):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
|
||||
hopper_desc = create_descriptor(
|
||||
memref_arg, leading_byte_offset, stride_byte_offset, swizzle, memory_space=3)
|
||||
blackwell_bit = arith.shli(c(1,i64), c(46,i64))
|
||||
blackwell_desc = llvm.or_(hopper_desc, blackwell_bit)
|
||||
return blackwell_desc
|
||||
|
||||
|
||||
def smem_descriptor_increment_address(desc, nbytes):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
return arith.addi(desc, arith.shrui(c(nbytes,i64), c(4,i64)))
|
||||
|
||||
|
||||
def create_blackwell_fp16_mma_descriptor(
|
||||
m: int,
|
||||
n: int,
|
||||
dtype,
|
||||
transpose_b: bool = False,
|
||||
):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
desc = c(0,i32)
|
||||
def fieldset(val, bit):
|
||||
field = arith.shli(c(val,i32), c(bit, i32))
|
||||
nonlocal desc
|
||||
desc = arith.ori(desc, field)
|
||||
# encoding is dependent on .kind::<foo>, below is for .kind::f16
|
||||
# 1: 0 - sparsity selector if sparsity is enabled
|
||||
# 2 - sparsity: dense = 0, sparse = 1
|
||||
# 3 - saturate for integer types, irrelevant for fp
|
||||
# 5: 4 - output dtype
|
||||
fieldset(1, 4)
|
||||
# 6 - reserved
|
||||
# 9: 7 - A dtype: f16 = 0, b16 = 1
|
||||
if dtype == jnp.bfloat16:
|
||||
fieldset(1, 7)
|
||||
# 12:10 - B dtype: f16 = 0, b16 = 1
|
||||
if dtype == jnp.bfloat16:
|
||||
fieldset(1, 10)
|
||||
# 13 - negate A
|
||||
# 14 - negate B
|
||||
# 15 - transpose A
|
||||
# 16 - transpose B
|
||||
if transpose_b:
|
||||
fieldset(1, 16)
|
||||
# 22:17 - N >> 3
|
||||
fieldset(n >> 3, 17) # TODO validate field width
|
||||
# 23 - reserved
|
||||
# 28:24 - M >> 4
|
||||
fieldset(m >> 4, 24)
|
||||
# 29 - reserved
|
||||
# 31:30 - max shift under .ws: irrelevant here
|
||||
return desc
|
||||
|
||||
WARPSIZE = 32
|
||||
|
||||
m_tile = 128
|
||||
n_tile = 128
|
||||
k_tile = 64
|
||||
m = 16*m_tile
|
||||
n = 16*n_tile
|
||||
k = 16*k_tile
|
||||
# K = 16 is inherent to Blackwell MMA half precision instructions
|
||||
blackwell_mma_fp16_k = 16
|
||||
ashape = (m, k) # k-major (row major)
|
||||
ashape_tile = (m_tile, k_tile)
|
||||
bshape = (n, k) # k-major (column major)
|
||||
bshape_tile = (n_tile, k_tile)
|
||||
dshape = (m, n) # n-major (row major)
|
||||
in_dtype = jnp.bfloat16
|
||||
ddtype = jnp.float32
|
||||
grid = (n//n_tile, m//m_tile, 1)
|
||||
block = (8*WARPSIZE, 1, 1)
|
||||
|
||||
def kernel(ctx, a, b, d, smem):
|
||||
i1 = ir.IntegerType.get_signless(1)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
f32 = ir.F32Type.get()
|
||||
index = ir.IndexType.get()
|
||||
ptr = ir.Type.parse("!llvm.ptr")
|
||||
ptr6 = ir.Type.parse("!llvm.ptr<6>")
|
||||
tidx = gpu.thread_id(gpu.Dimension.x)
|
||||
warpid = arith.shrui(tidx, c(5,index))
|
||||
tma_warp = 0
|
||||
mma_warp = 1
|
||||
# need a full aligned warpgroup, here it's warps 4-7
|
||||
ldtm_warp_range = (4, 8)
|
||||
ptr6 = ir.Type.parse("!llvm.ptr<6>") # TMEM
|
||||
|
||||
@contextlib.contextmanager
|
||||
def only_warp(i):
|
||||
is_warp_i = arith.cmpi(arith.CmpIPredicate.eq, warpid, c(i,index))
|
||||
with ir.InsertionPoint(scf.IfOp(is_warp_i).then_block):
|
||||
yield
|
||||
scf.yield_([])
|
||||
swizzle = 128
|
||||
tile_k = 64 # TODO(apaszke): I think we need to tile TMA to change this.
|
||||
in_dtype = jnp.float16
|
||||
k_loop_iter = k // tile_k
|
||||
|
||||
@contextlib.contextmanager
|
||||
def only_warp_range(a, b):
|
||||
gea = arith.cmpi(arith.CmpIPredicate.uge, warpid, c(a,index))
|
||||
ltb = arith.cmpi(arith.CmpIPredicate.ult, warpid, c(b,index))
|
||||
predicate = arith.andi(gea, ltb)
|
||||
with ir.InsertionPoint(scf.IfOp(predicate).then_block):
|
||||
yield
|
||||
scf.yield_([])
|
||||
if m % tile_m != 0:
|
||||
raise ValueError(f"{m=} must be divisible by {tile_m=}")
|
||||
if n % tile_n != 0:
|
||||
raise ValueError(f"{n=} must be divisible by {tile_n=}")
|
||||
if k % tile_k != 0:
|
||||
raise ValueError(f"{k=} must be divisible by {tile_k=}")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def single_warp_thread():
|
||||
elected = nvvm.elect_sync(i1)
|
||||
with ir.InsertionPoint(scf.IfOp(elected).then_block):
|
||||
yield
|
||||
scf.yield_([])
|
||||
def kernel(ctx, a, b, d, smem):
|
||||
# TODO(apaszke): Use more SMEM slots to avoid oversynchronizing warps.
|
||||
a_smem, b_smem, barriers, tmem_addr = smem
|
||||
(ab_full_barrier, ab_empty_barrier, mma_done_barrier) = barriers
|
||||
|
||||
a_shared, b_shared, (ab_full_barrier, ab_empty_barrier, mma_done_barrier), tmem_addr = smem
|
||||
thread_idx = mgpu.thread_idx()
|
||||
warp_idx = mgpu.warp_idx(sync=True)
|
||||
warp_leader = nvvm.elect_sync(i1)
|
||||
|
||||
k_loop_iter = k//k_tile
|
||||
is_warp = lambda i: arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32))
|
||||
|
||||
def bytecount(shape, dtype):
|
||||
return int(np.prod(shape) * dtype.dtype.itemsize)
|
||||
with mgpu.when(arith.andi(is_warp(TMA_WARP), warp_leader)):
|
||||
m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m,index))
|
||||
n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n,index))
|
||||
@mgpu.fori(c(k_loop_iter, index), None)
|
||||
def _tma_body(ki, _):
|
||||
# TODO(apaszke): Use a predicate instead of a conditional.
|
||||
with mgpu.when(arith.cmpi(arith.CmpIPredicate.ugt, ki, c(0, index))):
|
||||
ab_empty_barrier.wait()
|
||||
ab_full_barrier.arrive_expect_tx(
|
||||
bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype)
|
||||
)
|
||||
k_start = arith.muli(ki, c(tile_k, index))
|
||||
common_args = dict(
|
||||
swizzle=swizzle, barrier=ab_full_barrier, arrive=False, uniform=False,
|
||||
)
|
||||
ctx.async_copy(
|
||||
src_ref=a,
|
||||
dst_ref=a_smem,
|
||||
gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)),
|
||||
**common_args,
|
||||
)
|
||||
ctx.async_copy(
|
||||
src_ref=b,
|
||||
dst_ref=b_smem,
|
||||
gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)),
|
||||
**common_args,
|
||||
)
|
||||
|
||||
txcount = bytecount(ashape_tile, in_dtype) + bytecount(bshape_tile, in_dtype)
|
||||
with only_warp(tma_warp), single_warp_thread():
|
||||
# XXX TODO move loop iteration into MLIR, otherwise we force unroll
|
||||
for i in range(k_loop_iter):
|
||||
if i > 0:
|
||||
ab_empty_barrier.wait()
|
||||
ab_full_barrier.arrive_expect_tx(txcount)
|
||||
m_start = arith.muli(gpu.block_id(gpu.Dimension.y), c(m_tile,index))
|
||||
n_start = arith.muli(gpu.block_id(gpu.Dimension.x), c(n_tile,index))
|
||||
k_start = i*k_tile
|
||||
ctx.async_copy(
|
||||
src_ref=a,
|
||||
dst_ref=a_shared,
|
||||
gmem_slice=(ds(m_start, m_tile), ds(k_start,k_tile)),
|
||||
swizzle=128,
|
||||
barrier=ab_full_barrier,
|
||||
arrive=False,
|
||||
uniform=False,
|
||||
)
|
||||
ctx.async_copy(
|
||||
src_ref=b,
|
||||
dst_ref=b_shared,
|
||||
gmem_slice=(ds(n_start, n_tile), ds(k_start,k_tile)),
|
||||
swizzle=128,
|
||||
barrier=ab_full_barrier,
|
||||
arrive=False,
|
||||
uniform=False,
|
||||
)
|
||||
with mgpu.when(is_warp(MMA_WARP)):
|
||||
ncols = c(b_smem.type.shape[0], i32)
|
||||
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
|
||||
tcgen05.tmem_alloc(tmem_addr_addr, ncols)
|
||||
tcgen05.tmem_relinquish_alloc_permit()
|
||||
with mgpu.when(warp_leader):
|
||||
tmem_addr_value = llvm.load(ptr6, tmem_addr_addr)
|
||||
idesc = tcgen05.create_instr_descriptor(
|
||||
m=tile_n, n=tile_n, acc_dtype=jnp.float32, input_dtype=in_dtype
|
||||
)
|
||||
@mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0))
|
||||
def _mma_body(ki, accumulate):
|
||||
adesc = tcgen05.create_smem_descriptor(
|
||||
a_smem, leading_byte_offset=16, stride_byte_offset=1024, swizzle=swizzle)
|
||||
bdesc = tcgen05.create_smem_descriptor(
|
||||
b_smem, leading_byte_offset=16, stride_byte_offset=1024, swizzle=swizzle)
|
||||
ab_full_barrier.wait()
|
||||
|
||||
with only_warp(mma_warp):
|
||||
ncols = c(b_shared.type.shape[0], i32)
|
||||
tmem_addr_addr = utils.memref_ptr(tmem_addr, memory_space=3)
|
||||
nvvm.tcgen05_alloc(tmem_addr_addr, ncols)
|
||||
nvvm.tcgen05_relinquish_alloc_permit()
|
||||
accumulate = 0
|
||||
with single_warp_thread():
|
||||
tmem_addr_value = llvm.load(ptr6, tmem_addr_addr)
|
||||
idesc = create_blackwell_fp16_mma_descriptor(m_tile, n_tile, in_dtype)
|
||||
for i in range(k_loop_iter):
|
||||
adesc = create_smem_descriptor(
|
||||
a_shared, leading_byte_offset=16, stride_byte_offset=1024, swizzle=128)
|
||||
bdesc = create_smem_descriptor(
|
||||
b_shared, leading_byte_offset=16, stride_byte_offset=1024, swizzle=128)
|
||||
ab_full_barrier.wait()
|
||||
for _ in range(4):
|
||||
nvvm.tcgen05_mma("f16", "cta_1", tmem_addr_value, adesc, bdesc, idesc, enable_input_d=c(accumulate,i1))
|
||||
accumulate = 1
|
||||
adesc = smem_descriptor_increment_address(adesc, blackwell_mma_fp16_k*2)
|
||||
bdesc = smem_descriptor_increment_address(bdesc, blackwell_mma_fp16_k*2)
|
||||
last_iter = i == k_loop_iter-1
|
||||
barrier = mma_done_barrier if last_iter else ab_empty_barrier
|
||||
nvvm.tcgen05_commit_arrive(barrier.get_ptr())
|
||||
# TODO(apaszke): Abstract this into a function.
|
||||
assert tile_k % BLACKWELL_MMA_FP16_K == 0
|
||||
def smem_descriptor_increment_address(desc, nbytes):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
return arith.addi(desc, arith.shrui(c(nbytes,i64), c(4,i64)))
|
||||
for _ in range(tile_k // BLACKWELL_MMA_FP16_K):
|
||||
tcgen05.mma("f16", 1, tmem_addr_value, adesc, bdesc, idesc, enable_input_d=accumulate)
|
||||
accumulate = arith.constant(i1, 1)
|
||||
adesc = smem_descriptor_increment_address(
|
||||
adesc, BLACKWELL_MMA_FP16_K * 2
|
||||
)
|
||||
bdesc = smem_descriptor_increment_address(
|
||||
bdesc, BLACKWELL_MMA_FP16_K * 2
|
||||
)
|
||||
|
||||
with only_warp_range(*ldtm_warp_range), ctx.named_region("LDTM"):
|
||||
is_last_iter = arith.cmpi(
|
||||
arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index)
|
||||
)
|
||||
barrier_ptr = arith.select(
|
||||
is_last_iter, mma_done_barrier.get_ptr(), ab_empty_barrier.get_ptr()
|
||||
)
|
||||
tcgen05.commit_arrive(barrier_ptr)
|
||||
return accumulate
|
||||
|
||||
# TODO(apaszke): Should we have a warpgroup that's dedicated to this store?
|
||||
gpu.barrier()
|
||||
|
||||
# TODO(apaszke): This has a very slow GMEM access pattern.
|
||||
mma_done_barrier.wait()
|
||||
tmem_ptr = llvm.inttoptr(ptr6, memref.load(tmem_addr, [c(0,index)]))
|
||||
# TODO automate type creation
|
||||
vector_i32 = nvvm.tcgen05_ld(ir.VectorType.get([n_tile], i32), shape="shape_32x32b", num=n_tile, tmem_addr=tmem_ptr)
|
||||
vector_i32 = tcgen05.tmem_load(num=tile_n, tmem_addr=tmem_ptr)
|
||||
vector_f32 = vector.bitcast(ir.VectorType.get(vector_i32.type.shape, f32), vector_i32)
|
||||
wg_tidx = arith.andi(tidx, c((1 << 7)-1, index)) # tidx within warpgroup
|
||||
wg_tidx = arith.remui(
|
||||
arith.index_castui(index, thread_idx), c(utils.WARPGROUP_SIZE, index)
|
||||
)
|
||||
row = arith.addi(
|
||||
arith.muli(
|
||||
gpu.block_id(gpu.Dimension.y),
|
||||
c(m_tile,index)),
|
||||
wg_tidx)
|
||||
column = arith.muli(
|
||||
gpu.block_id(gpu.Dimension.x),
|
||||
c(n_tile,index))
|
||||
arith.muli(gpu.block_id(gpu.Dimension.y), c(tile_m, index)), wg_tidx
|
||||
)
|
||||
column = arith.muli(gpu.block_id(gpu.Dimension.x), c(tile_n, index))
|
||||
vector.store(vector_f32, d, [row, column])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ka, kb = jr.split(jr.key(0), 2)
|
||||
a = jr.uniform(key=ka, shape=ashape, dtype=in_dtype)
|
||||
b = jr.uniform(key=kb, shape=bshape, dtype=in_dtype)
|
||||
|
||||
asds_tile = SDS(ashape_tile, a.dtype)
|
||||
bsds_tile = SDS(bshape_tile, b.dtype)
|
||||
dsds = SDS(dshape, ddtype)
|
||||
tmem_addr = SDS((1,), np.uint32)
|
||||
|
||||
smem = (asds_tile, bsds_tile, tuple(mgpu.Barrier(arrival_count=1) for _ in range(3)), tmem_addr)
|
||||
with mlir.make_ir_context(), ir.Location.unknown():
|
||||
f = mgpu.as_gpu_kernel(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
(SDS(a.shape, a.dtype), SDS(b.shape, b.dtype)),
|
||||
dsds,
|
||||
smem
|
||||
)
|
||||
y = f(a, b)
|
||||
|
||||
@jax.jit
|
||||
def ref_f(x, y):
|
||||
return jax.lax.dot_general(
|
||||
x,
|
||||
y,
|
||||
dimension_numbers=(((1,), (1,)), ((), ())),
|
||||
preferred_element_type=jnp.float32,
|
||||
).astype(jnp.float32)
|
||||
|
||||
ref = ref_f(a, b)
|
||||
np.testing.assert_allclose(
|
||||
y.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
|
||||
smem = (
|
||||
jax.ShapeDtypeStruct((tile_m, tile_k), jnp.float16),
|
||||
jax.ShapeDtypeStruct((tile_n, tile_k), jnp.float16),
|
||||
[mgpu.Barrier(arrival_count=1)] * 3,
|
||||
jax.ShapeDtypeStruct((1,), np.uint32), # TMEM address
|
||||
)
|
||||
return mgpu.as_gpu_kernel(
|
||||
kernel,
|
||||
(n // tile_n, m // tile_m, 1),
|
||||
(128, 1, 1),
|
||||
(
|
||||
jax.ShapeDtypeStruct((m, k), jnp.float16),
|
||||
jax.ShapeDtypeStruct((n, k), jnp.float16),
|
||||
),
|
||||
jax.ShapeDtypeStruct((m, n), jnp.float32),
|
||||
smem,
|
||||
)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
m_tile = 128
|
||||
n_tile = 128
|
||||
k_tile = 64
|
||||
m = 16*m_tile
|
||||
n = 16*n_tile
|
||||
k = 16*k_tile
|
||||
|
||||
ka, kb = jr.split(jr.key(0), 2)
|
||||
a = jr.normal(key=ka, shape=(m, k), dtype=jnp.float16)
|
||||
b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16)
|
||||
|
||||
with mlir.make_ir_context(), ir.Location.unknown():
|
||||
f = build_kernel(m, n, k, tile_m=m_tile, tile_n=n_tile)
|
||||
y = f(a, b).block_until_ready()
|
||||
|
||||
ref = np.asarray(a) @ np.asarray(b).T
|
||||
np.testing.assert_allclose(y, ref, atol=1e-3, rtol=1e-3)
|
||||
print("OK!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from absl import app
|
||||
import jax
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
|
133
jax/experimental/mosaic/gpu/tcgen05.py
Normal file
133
jax/experimental/mosaic/gpu/tcgen05.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2025 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
import numpy as np
|
||||
|
||||
from . import _wgmma
|
||||
|
||||
def create_smem_descriptor(
|
||||
memref_arg,
|
||||
leading_byte_offset: int,
|
||||
stride_byte_offset: int,
|
||||
swizzle: int | mgpu_dialect.SwizzlingMode | None,
|
||||
):
|
||||
blackwell_bit = 1 << 46
|
||||
return _wgmma.create_descriptor(
|
||||
memref_arg,
|
||||
leading_byte_offset,
|
||||
stride_byte_offset,
|
||||
swizzle,
|
||||
memory_space=3,
|
||||
const_init=blackwell_bit,
|
||||
)
|
||||
|
||||
def create_instr_descriptor(
|
||||
m: int,
|
||||
n: int,
|
||||
acc_dtype,
|
||||
input_dtype,
|
||||
transpose_a: bool = False,
|
||||
transpose_b: bool = False,
|
||||
):
|
||||
if input_dtype not in {np.float16, dtypes.bfloat16}:
|
||||
raise NotImplementedError("Only float16 and bfloat16 inputs supported")
|
||||
if acc_dtype not in {np.float32, np.float16}:
|
||||
raise NotImplementedError("Only float32 and float16 accumulators supported")
|
||||
|
||||
desc = 0
|
||||
# We ignore sparsity in bits 0-3
|
||||
desc |= (acc_dtype == np.float32) << 4 # D dtype, bits 4-5
|
||||
# Bit 6 is reserved
|
||||
desc |= (input_dtype == dtypes.bfloat16) << 7 # A dtype, bits 7-9
|
||||
desc |= (input_dtype == dtypes.bfloat16) << 10 # B dtype, bits 10-12
|
||||
# We ignore negate bits 13-14
|
||||
desc |= transpose_a << 15 # Transpose A
|
||||
desc |= transpose_b << 16 # Transpose B
|
||||
if n % 8 or n > 256:
|
||||
raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
|
||||
desc |= (n >> 3) << 17 # N, bits 17-22
|
||||
# Bit 23 is reserved
|
||||
if m % 16 or m > 256:
|
||||
raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
|
||||
desc |= (m >> 4) << 24 # M >> 4, bits 24-28
|
||||
# Bit 29 is reserved
|
||||
# We ignore max shift under .ws, bits 30-31
|
||||
return arith.constant(ir.IntegerType.get_signless(32), desc) # type: ignore
|
||||
|
||||
|
||||
def mma(dtype, num_cta, d_tmem, adesc, bdesc, idesc, enable_input_d):
|
||||
if not (1 <= num_cta <= 2):
|
||||
raise ValueError(f"num_cta must be 1 or 2, got: {num_cta}")
|
||||
return llvm.inline_asm(
|
||||
ir.Type.parse("!llvm.void"),
|
||||
[d_tmem, adesc, bdesc, idesc, enable_input_d],
|
||||
f"tcgen05.mma.cta_group::1.kind::{dtype} [$0], $1, $2, $3, $4;",
|
||||
"r,l,l,r,b",
|
||||
has_side_effects=True,
|
||||
)
|
||||
|
||||
def commit_arrive(barrier):
|
||||
return llvm.inline_asm(
|
||||
ir.Type.parse("!llvm.void"),
|
||||
[barrier],
|
||||
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$0];",
|
||||
"l",
|
||||
has_side_effects=True
|
||||
)
|
||||
|
||||
def tmem_alloc(tmem_addr, ncols):
|
||||
return llvm.inline_asm(
|
||||
ir.Type.parse("!llvm.void"),
|
||||
[tmem_addr, ncols],
|
||||
"tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$0], $1;",
|
||||
"r,r",
|
||||
has_side_effects=True,
|
||||
)
|
||||
|
||||
def tmem_relinquish_alloc_permit():
|
||||
return llvm.inline_asm(
|
||||
ir.Type.parse("!llvm.void"),
|
||||
[],
|
||||
"tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;",
|
||||
"",
|
||||
has_side_effects=True,
|
||||
)
|
||||
|
||||
def tmem_load(tmem_addr, num):
|
||||
if num.bit_count() != 1 or num > 128:
|
||||
raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
out_regs = ",".join("$" + str(i) for i in range(num))
|
||||
regs = llvm.inline_asm(
|
||||
ir.Type.parse(
|
||||
"!llvm.struct<(" + ",".join("i32" for _ in range(num)) + ")>"
|
||||
),
|
||||
[tmem_addr],
|
||||
f"tcgen05.ld.sync.aligned.32x32b.x{num}.b32 {{{out_regs}}}, [${num}];",
|
||||
"=r," * num + "r",
|
||||
has_side_effects=True,
|
||||
)
|
||||
out_ty = ir.VectorType.get([num], i32)
|
||||
out_vec = llvm.mlir_undef(out_ty)
|
||||
for i in range(num):
|
||||
out_vec = llvm.insertelement(
|
||||
out_vec, llvm.extractvalue(i32, regs, [i]), arith.constant(i32, i)
|
||||
)
|
||||
return out_vec
|
@ -100,6 +100,7 @@ def create_descriptor(
|
||||
stride_byte_offset: int,
|
||||
swizzle: int | mgpu_dialect.SwizzlingMode | None,
|
||||
memory_space: int | None = None,
|
||||
const_init: int = 0,
|
||||
):
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
ptr_val = llvm.ptrtoint(i64, utils.memref_ptr(memref_arg, memory_space))
|
||||
@ -118,7 +119,8 @@ def create_descriptor(
|
||||
)
|
||||
# We ignore the offset
|
||||
desc_const = (
|
||||
(wgmma_encode(leading_byte_offset) << 16)
|
||||
const_init
|
||||
| (wgmma_encode(leading_byte_offset) << 16)
|
||||
| (wgmma_encode(stride_byte_offset) << 32)
|
||||
)
|
||||
desc = llvm.or_(
|
||||
|
Loading…
x
Reference in New Issue
Block a user