[Mosaic GPU] Add basic support for TMA with sub-byte types

PiperOrigin-RevId: 719240287
Adam Paszke 2025-01-24 03:53:19 -08:00 committed by jax authors
parent 313e35a214
commit 7043b852ec
4 changed files with 106 additions and 16 deletions


@@ -290,7 +290,7 @@ class LaunchContext:
       args = [
           host_ptr,
           base_ptr,
-          c(utils.bytewidth(ref_ty.element_type), i64),
+          c(utils.bitwidth(ref_ty.element_type), i64),
           c(rank, i64),
           utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
           utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),


@@ -1067,13 +1067,24 @@ def memref_ptr(memref_arg, memory_space=None):
   desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
   aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
-  elem_bytewidth = bytewidth(memref_ty.element_type)
   offset_elems = llvm.extractvalue(i64, desc, [2])
-  offset_bytes = llvm.mul(
-      offset_elems,
-      c(elem_bytewidth, i64),
-      overflow_flags=llvm.IntegerOverflowFlags.none,
-  )
+  elem_bitwidth = bitwidth(memref_ty.element_type)
+  if elem_bitwidth < 8:
+    *_, static_offset = memref_ty.get_strides_and_offset()
+    if static_offset == ir.ShapedType.get_dynamic_stride_or_offset():
+      raise NotImplementedError
+    assert elem_bitwidth.bit_count() == 1
+    packing = 8 // elem_bitwidth
+    if static_offset % packing != 0:
+      raise ValueError
+    offset_bytes = c(static_offset // packing, i64)
+  else:
+    assert elem_bitwidth % 8 == 0
+    offset_bytes = llvm.mul(
+        offset_elems,
+        c(elem_bitwidth // 8, i64),
+        overflow_flags=llvm.IntegerOverflowFlags.none,
+    )
   return llvm.inttoptr(
       ptr_ty,
       llvm.add(

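For sub-byte element types, memref_ptr above can no longer scale a dynamic element offset into a byte offset, so it insists on a static offset that is a multiple of the packing factor and folds it to bytes at compile time. A small worked example of that arithmetic (the numbers are illustrative, not taken from the diff):

  packing = 8 // 4                 # int4: two elements per byte
  for static_offset in (10, 7):    # offsets in elements
    if static_offset % packing:
      print(static_offset, "-> not byte aligned; memref_ptr raises ValueError")
    else:
      print(static_offset, "->", static_offset // packing, "bytes")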

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <cassert>
 #include <cstdint>
 #include <cstdio>
@@ -21,7 +22,7 @@ limitations under the License.
 extern "C" {
 void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
-                              int64_t elem_bytewidth, int64_t rank,
+                              int64_t elem_bitwidth, int64_t rank,
                               int64_t *sizes, int64_t *strides,
                               int64_t swizzle_bytes, int64_t *window_shape) {
   if (((uintptr_t)tma_desc) % 64 != 0) {
@@ -31,6 +32,28 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
     abort();
   }
+  // Pack 4 bit types in 8 bit pairs.
+  int64_t elem_bytewidth;
+  if (elem_bitwidth < 8) {
+    // Check that it's a power of 2.
+    assert((elem_bitwidth & (elem_bitwidth - 1)) == 0);
+    int packing = 8 / elem_bitwidth;
+    assert(sizes[rank - 1] % packing == 0);
+    assert(window_shape[rank - 1] % packing == 0);
+    assert(strides[rank - 1] == 1);
+    // TMA requires that the last dimension be the contiguous one so we pack
+    // the elements under that assumption.
+    sizes[rank - 1] /= packing;
+    window_shape[rank - 1] /= packing;
+    for (int i = 0; i < rank - 1; i++) {
+      strides[i] /= packing;
+    }
+    elem_bytewidth = 1;
+  } else {
+    elem_bytewidth = elem_bitwidth / 8;
+  }
   CUtensorMapDataType data_type;
   if (elem_bytewidth == 1) {
     data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;

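The packing above, restated as a standalone Python sketch. The helper name is mine and mirrors the C++ logic rather than any function in the tree: a power-of-two sub-byte type is folded into the last (contiguous) dimension, and the TMA descriptor is then built as if the elements were one byte wide.

  def pack_subbyte_layout(elem_bitwidth, sizes, strides, window_shape):
    # Mirrors mosaic_gpu_init_tma_desc: pack sub-byte elements into bytes along
    # the last dimension, which TMA requires to be contiguous (stride 1).
    if elem_bitwidth >= 8:
      return elem_bitwidth // 8, sizes, strides, window_shape
    assert elem_bitwidth & (elem_bitwidth - 1) == 0  # power of two
    packing = 8 // elem_bitwidth
    assert sizes[-1] % packing == 0 and window_shape[-1] % packing == 0
    assert strides[-1] == 1
    sizes = [*sizes[:-1], sizes[-1] // packing]
    window_shape = [*window_shape[:-1], window_shape[-1] // packing]
    strides = [s // packing for s in strides[:-1]] + [1]
    return 1, sizes, strides, window_shape  # elem_bytewidth is now 1

  # A (64, 128) int4 ref with strides (128, 1) and a (64, 64) window becomes a
  # (64, 64) byte array with strides (64, 1) and a (64, 32) window:
  print(pack_subbyte_layout(4, [64, 128], [128, 1], [64, 64]))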

@@ -88,6 +88,7 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
     thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride))
     stride = arith.muli(stride, gpu.block_dim(dim))
   is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index))
   src_ty = ir.MemRefType(src.type)
   dst_ty = ir.MemRefType(dst.type)
   if src_ty.shape != dst_ty.shape:
@@ -95,13 +96,65 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
         f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}"
     )
   shape = src_ty.shape
-  dyn_strides = [c(s, index) for s in get_contiguous_strides(shape)]
+  if src_ty.element_type != dst_ty.element_type:
+    raise ValueError(
+        f"src and dst element types don't match: {src_ty.element_type} !="
+        f" {dst_ty.element_type}"
+    )
+  contig_strides = get_contiguous_strides(shape)
+  # If swizzling is on, at least one of the memrefs must be contiguous
+  # (simulating a TMA).
+  if (swizzle is not None and
+      src_ty.get_strides_and_offset()[0] != contig_strides and
+      dst_ty.get_strides_and_offset()[0] != contig_strides):
+    raise NotImplementedError(src_ty, dst_ty)
+  bw = bitwidth(src_ty.element_type)
+  if bw < 8:
+    assert bw.bit_count() == 1
+    packing = 8 // bw
+    if shape[-1] % packing:
+      raise NotImplementedError
+    workgroup_mem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    shape = (*shape[:-1], shape[-1] // packing)
+    contig_strides = get_contiguous_strides(shape)
+    def bitcast(ref):
+      ref_ty = ir.MemRefType(ref.type)
+      old_strides = ref_ty.get_strides_and_offset()[0]
+      if old_strides[-1] != 1:
+        raise NotImplementedError
+      new_strides = [s // packing for s in old_strides[:-1]] + [1]
+      new_ref_ty = ir.MemRefType.get(
+          shape,
+          ir.VectorType.get((packing,), src_ty.element_type),  # noqa: F821
+          ir.StridedLayoutAttr.get(0, new_strides),
+          ref_ty.memory_space,
+      )
+      ptr_space = (
+          3
+          if ref_ty.memory_space is not None
+          and ref_ty.memory_space == workgroup_mem
+          else None
+      )
+      return ptr_as_memref(
+          # NOTE: memref_ptr applies the offset in case there was any.
+          memref_ptr(ref, memory_space=ptr_space),
+          new_ref_ty,
+          ptr_memory_space=ptr_space,
+      )
+    src = bitcast(src)
+    dst = bitcast(dst)
+    bw = 8
+  del src_ty, dst_ty  # If you remove this, update it in the branch above
+  dyn_strides = [c(s, index) for s in contig_strides]
   with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block):
     def body(*idx):
       dst_idx = idx
       if swizzle is not None:
         assert swizzle.bit_count() == 1
-        bytes_per_element = bytewidth(src_ty.element_type)
+        assert bw % 8 == 0
+        bytes_per_element = bw // 8
         linear_idx = c(0, index)
         for stride, i in zip(dyn_strides, idx):
           linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
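The bitcast helper in copy() deserves a note: rather than keeping an iN element type, it reinterprets the ref as a memref of vector<packing x iN>, with the minor dimension and the strides divided by the packing factor. Roughly, in my own notation (the concrete shape is illustrative, not a line from the diff):

  memref<64x128xi4, strided<[128, 1]>>  ->  memref<64x64xvector<2xi4>, strided<[64, 1]>>

Each element of the rewritten ref is exactly one byte, so the plain per-element loads and stores in body() move whole bytes and never have to split a packed pair, which is why bw can be reset to 8 afterwards.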
@@ -963,10 +1016,11 @@ class TMATest(TestCase):
   @parameterized.product(
       swizzle=(None, 32, 64, 128),
       shape=((64, None), (5, None), (2, 3, 5, None)),
-      dtype=(jnp.float16, jnp.float32),
+      dtype=(jnp.float32, jnp.float16, jnp.int4),
   )
   def test_tma_load_basic(self, swizzle, shape, dtype):
-    minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
+    bw = bitwidth(dtype_to_ir_type(dtype))
+    minor_size = 64 if swizzle is None else 8 * swizzle // bw
     shape = (*shape[:-1], minor_size)
     i1 = ir.IntegerType.get_signless(1)
     def kernel(ctx, src, dst, smem):
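A quick sanity check of the new minor_size expression (my own numbers, not part of the test): swizzle counts bytes, so a swizzled row holds 8 * swizzle // bw elements, which the old swizzle // jnp.dtype(dtype).itemsize form cannot express for sub-byte dtypes, since itemsize is measured in whole bytes.

  for bw, swizzle in ((32, 128), (16, 128), (4, 32)):   # bits, bytes
    print(bw, swizzle, 8 * swizzle // bw)               # -> 32, 64, 64 elements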
@@ -1044,12 +1098,14 @@
             idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))
         )
         stride *= cluster[d]
-      slc = ds(
-          arith.muli(idx, c(16, index)), 16
+      idx_minor = arith.divui(idx, c(2, index))
+      idx_major = arith.remui(idx, c(2, index))
+      slc_minor = ds(
+          arith.muli(idx_minor, c(16 * 2, index)), 16 * 2
       )
       copy(
-          memref_slice(tmp, (slice(None), slc)),
-          memref_slice(dst, (noncollective_idx, slice(None), slc)),
+          memref_slice(tmp, (idx_major, slc_minor)),
+          memref_slice(dst, (noncollective_idx, idx_major, slc_minor)),
           swizzle=swizzle,
       )
     x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)