Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 12:26:07 +00:00
[Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 719240287
This commit is contained in:
parent 313e35a214
commit 7043b852ec
@@ -290,7 +290,7 @@ class LaunchContext:
     args = [
         host_ptr,
         base_ptr,
-        c(utils.bytewidth(ref_ty.element_type), i64),
+        c(utils.bitwidth(ref_ty.element_type), i64),
         c(rank, i64),
         utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
         utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
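Note: this argument list feeds the runtime's TMA-descriptor initialization (mosaic_gpu_init_tma_desc, changed below), and it now carries the element bit width rather than the byte width so that sub-byte element types such as 4-bit integers stay representable. A hedged plain-Python illustration of why a byte width cannot describe such a type (the helper below is hypothetical, not part of the library):

# An i4 element has no whole byte width; forwarding bits lets the runtime
# derive the packing itself.
def exact_bytewidth(bits: int) -> int:
  if bits % 8:
    raise ValueError(f"{bits}-bit elements have no whole byte width")
  return bits // 8

assert exact_bytewidth(16) == 2   # f16: fine
try:
  exact_bytewidth(4)              # i4: impossible, hence the bitwidth argument
except ValueError:
  pass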
@@ -1067,13 +1067,24 @@ def memref_ptr(memref_arg, memory_space=None):
   desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
   aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])

-  elem_bytewidth = bytewidth(memref_ty.element_type)
   offset_elems = llvm.extractvalue(i64, desc, [2])
-  offset_bytes = llvm.mul(
-      offset_elems,
-      c(elem_bytewidth, i64),
-      overflow_flags=llvm.IntegerOverflowFlags.none,
-  )
+  elem_bitwidth = bitwidth(memref_ty.element_type)
+  if elem_bitwidth < 8:
+    *_, static_offset = memref_ty.get_strides_and_offset()
+    if static_offset == ir.ShapedType.get_dynamic_stride_or_offset():
+      raise NotImplementedError
+    assert elem_bitwidth.bit_count() == 1
+    packing = 8 // elem_bitwidth
+    if static_offset % packing != 0:
+      raise ValueError
+    offset_bytes = c(static_offset // packing, i64)
+  else:
+    assert elem_bitwidth % 8 == 0
+    offset_bytes = llvm.mul(
+        offset_elems,
+        c(elem_bitwidth // 8, i64),
+        overflow_flags=llvm.IntegerOverflowFlags.none,
+    )
   return llvm.inttoptr(
       ptr_ty,
       llvm.add(
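Note: the branch above converts the memref's element offset into a byte offset. For byte-aligned types this stays a plain multiply; for sub-byte types the (static) offset must land on a byte boundary, so it is divided by the packing factor instead. A hedged plain-Python sketch of the same arithmetic (not the MLIR-emitting code):

def byte_offset(offset_elems: int, elem_bitwidth: int) -> int:
  if elem_bitwidth < 8:
    assert elem_bitwidth & (elem_bitwidth - 1) == 0  # power of two, e.g. 4-bit
    packing = 8 // elem_bitwidth  # elements per byte (2 for i4)
    if offset_elems % packing != 0:
      raise ValueError("sub-byte offset does not start on a byte boundary")
    return offset_elems // packing
  assert elem_bitwidth % 8 == 0
  return offset_elems * (elem_bitwidth // 8)

assert byte_offset(6, 4) == 3    # six i4 elements occupy three bytes
assert byte_offset(6, 16) == 12  # six f16 elements occupy twelve bytes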
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <cassert>
 #include <cstdint>
 #include <cstdio>

@@ -21,7 +22,7 @@ limitations under the License.
 extern "C" {

 void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
-                              int64_t elem_bytewidth, int64_t rank,
+                              int64_t elem_bitwidth, int64_t rank,
                               int64_t *sizes, int64_t *strides,
                               int64_t swizzle_bytes, int64_t *window_shape) {
   if (((uintptr_t)tma_desc) % 64 != 0) {
@@ -31,6 +32,28 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
     abort();
   }

+  // Pack 4 bit types in 8 bit pairs.
+  int64_t elem_bytewidth;
+  if (elem_bitwidth < 8) {
+    // Check that it's a power of 2.
+    assert((elem_bitwidth & (elem_bitwidth - 1)) == 0);
+    int packing = 8 / elem_bitwidth;
+    assert(sizes[rank - 1] % packing == 0);
+    assert(window_shape[rank - 1] % packing == 0);
+    assert(strides[rank - 1] == 1);
+
+    // TMA requires that the last dimension be the contiguous one so we pack the
+    // elements under that assumption.
+    sizes[rank - 1] /= packing;
+    window_shape[rank - 1] /= packing;
+    for (int i = 0; i < rank - 1; i++) {
+      strides[i] /= packing;
+    }
+    elem_bytewidth = 1;
+  } else {
+    elem_bytewidth = elem_bitwidth / 8;
+  }
+
   CUtensorMapDataType data_type;
   if (elem_bytewidth == 1) {
     data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
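Note: the C++ block above is the core of the sub-byte support. A 4-bit array is presented to the TMA engine as an array of bytes by packing elements along the contiguous (last) dimension, which requires the last stride to be 1 and the affected extents to be divisible by the packing factor. A hedged Python sketch of the same transformation (stand-alone illustration, not the runtime code):

def pack_tma_dims(elem_bitwidth, sizes, strides, window_shape):
  if elem_bitwidth >= 8:
    return elem_bitwidth // 8, sizes, strides, window_shape
  assert elem_bitwidth & (elem_bitwidth - 1) == 0  # power of two
  packing = 8 // elem_bitwidth
  assert strides[-1] == 1                      # last dim must be contiguous
  assert sizes[-1] % packing == 0 and window_shape[-1] % packing == 0
  sizes = [*sizes[:-1], sizes[-1] // packing]
  window_shape = [*window_shape[:-1], window_shape[-1] // packing]
  strides = [s // packing for s in strides[:-1]] + [1]
  return 1, sizes, strides, window_shape       # effective element is one byte

# Two i4 elements share a byte, so a 128x128 i4 tile becomes 128x64 bytes:
assert pack_tma_dims(4, [128, 128], [128, 1], [64, 64]) == (
    1, [128, 64], [64, 1], [64, 32])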
@@ -88,6 +88,7 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
     thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride))
     stride = arith.muli(stride, gpu.block_dim(dim))
   is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index))
+
   src_ty = ir.MemRefType(src.type)
   dst_ty = ir.MemRefType(dst.type)
   if src_ty.shape != dst_ty.shape:
@@ -95,13 +96,65 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
         f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}"
     )
   shape = src_ty.shape
-  dyn_strides = [c(s, index) for s in get_contiguous_strides(shape)]
+  if src_ty.element_type != dst_ty.element_type:
+    raise ValueError(
+        f"src and dst element types don't match: {src_ty.element_type} !="
+        f" {dst_ty.element_type}"
+    )
+  contig_strides = get_contiguous_strides(shape)
+  # If swizzling is on, at least one of the memrefs must be contiguous
+  # (simulating a TMA).
+  if (swizzle is not None and
+      src_ty.get_strides_and_offset()[0] != contig_strides and
+      dst_ty.get_strides_and_offset()[0] != contig_strides):
+    raise NotImplementedError(src_ty, dst_ty)
+
+  bw = bitwidth(src_ty.element_type)
+  if bw < 8:
+    assert bw.bit_count() == 1
+    packing = 8 // bw
+    if shape[-1] % packing:
+      raise NotImplementedError
+    workgroup_mem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    shape = (*shape[:-1], shape[-1] // packing)
+    contig_strides = get_contiguous_strides(shape)
+    def bitcast(ref):
+      ref_ty = ir.MemRefType(ref.type)
+      old_strides = ref_ty.get_strides_and_offset()[0]
+      if old_strides[-1] != 1:
+        raise NotImplementedError
+      new_strides = [s // packing for s in old_strides[:-1]] + [1]
+      new_ref_ty = ir.MemRefType.get(
+          shape,
+          ir.VectorType.get((packing,), src_ty.element_type),  # noqa: F821
+          ir.StridedLayoutAttr.get(0, new_strides),
+          ref_ty.memory_space,
+      )
+      ptr_space = (
+          3
+          if ref_ty.memory_space is not None
+          and ref_ty.memory_space == workgroup_mem
+          else None
+      )
+      return ptr_as_memref(
+          # NOTE: memref_ptr applies the offset in case there was any.
+          memref_ptr(ref, memory_space=ptr_space),
+          new_ref_ty,
+          ptr_memory_space=ptr_space,
+      )
+    src = bitcast(src)
+    dst = bitcast(dst)
+    bw = 8
+  del src_ty, dst_ty  # If you remove this, update it in the branch above
+  dyn_strides = [c(s, index) for s in contig_strides]
+
   with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block):
     def body(*idx):
       dst_idx = idx
       if swizzle is not None:
         assert swizzle.bit_count() == 1
-        bytes_per_element = bytewidth(src_ty.element_type)
+        assert bw % 8 == 0
+        bytes_per_element = bw // 8
         linear_idx = c(0, index)
         for stride, i in zip(dyn_strides, idx):
           linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
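Note: the test helper copy() cannot address sub-byte elements directly, so the hunk above reinterprets each i4 memref as a memref of byte-sized vector<2xi4> elements, halving the minor dimension and the strides, and then copies those byte-sized groups. A hedged NumPy analogue of that reinterpretation (illustration only; the real code builds MLIR memref types):

import numpy as np

def view_packed(x: np.ndarray, packing: int) -> np.ndarray:
  # Group `packing` consecutive minor-dim elements, the way vector<2xi4>
  # groups two i4 values into one byte-sized element.
  assert x.shape[-1] % packing == 0
  return x.reshape(*x.shape[:-1], x.shape[-1] // packing, packing)

src = np.arange(4 * 8, dtype=np.uint8).reshape(4, 8)  # stand-in for an i4 array
dst = np.empty_like(src)
view_packed(dst, 2)[...] = view_packed(src, 2)  # copy through the packed view
np.testing.assert_array_equal(src, dst)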
@@ -963,10 +1016,11 @@ class TMATest(TestCase):
   @parameterized.product(
       swizzle=(None, 32, 64, 128),
       shape=((64, None), (5, None), (2, 3, 5, None)),
-      dtype=(jnp.float16, jnp.float32),
+      dtype=(jnp.float32, jnp.float16, jnp.int4),
   )
   def test_tma_load_basic(self, swizzle, shape, dtype):
-    minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
+    bw = bitwidth(dtype_to_ir_type(dtype))
+    minor_size = 64 if swizzle is None else 8 * swizzle // bw
     shape = (*shape[:-1], minor_size)
     i1 = ir.IntegerType.get_signless(1)
     def kernel(ctx, src, dst, smem):
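Note: the test now derives the minor dimension from the element bit width, since jnp.dtype(...).itemsize is measured in whole bytes and cannot represent a 4-bit element exactly: a swizzle of N bytes spans 8 * N / bitwidth elements. A hedged sketch of that arithmetic:

def minor_size(swizzle_bytes, elem_bits):
  return 64 if swizzle_bytes is None else 8 * swizzle_bytes // elem_bits

assert minor_size(128, 32) == 32   # f32: 128 bytes are 32 elements
assert minor_size(128, 16) == 64   # f16: 128 bytes are 64 elements
assert minor_size(128, 4) == 256   # i4: 128 bytes are 256 elements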
@@ -1044,12 +1098,14 @@ class TMATest(TestCase):
             idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))
         )
         stride *= cluster[d]
-      slc = ds(
-          arith.muli(idx, c(16, index)), 16
+      idx_minor = arith.divui(idx, c(2, index))
+      idx_major = arith.remui(idx, c(2, index))
+      slc_minor = ds(
+          arith.muli(idx_minor, c(16 * 2, index)), 16 * 2
       )
       copy(
-          memref_slice(tmp, (slice(None), slc)),
-          memref_slice(dst, (noncollective_idx, slice(None), slc)),
+          memref_slice(tmp, (idx_major, slc_minor)),
+          memref_slice(dst, (noncollective_idx, idx_major, slc_minor)),
           swizzle=swizzle,
       )
     x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
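Note: in the collective test above, two packed elements share a byte, so the 16-wide slice each block used to copy becomes a 16 * 2 element slice, and the block index is split with divui/remui into a part that selects the (now twice as wide) minor slice and a part that selects along the leading dimension. A hedged plain-Python sketch of that index split:

def split_slice(idx, slice_elems=16, packing=2):
  idx_minor, idx_major = divmod(idx, packing)  # mirrors arith.divui / arith.remui
  start = idx_minor * slice_elems * packing
  return idx_major, slice(start, start + slice_elems * packing)

assert split_slice(0) == (0, slice(0, 32))
assert split_slice(1) == (1, slice(0, 32))
assert split_slice(2) == (0, slice(32, 64))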