Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[Mosaic GPU] Add support for tiled loads/stores with sub-byte types

Apparently MLIR and LLVM love to pad sub-byte types to whole bytes, so only the code where we do address arithmetic ourselves is easy to adapt.

PiperOrigin-RevId: 720593538
commit f504d32492
parent e332b94f19
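As a rough illustration of the addressing problem behind this change (hypothetical numbers, not code from the commit): a single i4 element can sit at a bit offset that is not a whole byte, so byte-addressed pointer arithmetic in element units cannot reach it, whereas offsets counted in whole vectors of i4 (whose total width is a multiple of 8 bits) always land on byte boundaries.

    # Sketch (hypothetical values): element-unit vs. vector-unit offsets for i4.
    element_bits = 4
    vector_length = 8                                     # 8 x i4 = 32 bits
    transfer_bytes = vector_length * element_bits // 8    # 4 whole bytes
    # Element index 3 starts at bit 12, i.e. 1.5 bytes into the buffer.
    assert (3 * element_bits) % 8 != 0
    # Vector index 3 starts at byte 3 * transfer_bytes = 12, a whole-byte offset.
    assert (3 * vector_length * element_bits) % 8 == 0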
@@ -1060,6 +1060,11 @@ class FragmentedArray:
       return FragmentedArray(
           _registers=self.registers, _layout=self.layout, _is_signed=is_signed
       )
+    # XLA packs elements into bytes in big-endian order, while LLVM assumes the
+    # same endianness as the target machine (which is little for NVIDIA GPUs).
+    # We'll need to add specialized casting routines that flip the endianness.
+    if 1 < utils.bitwidth(cur_dtype) < 8 or 1 < utils.bitwidth(new_dtype) < 8:
+      raise NotImplementedError("Conversion involving sub-byte types unsupported", cur_dtype, new_dtype)
     reg_type = self.registers.flat[0].type
     is_vector_reg = ir.VectorType.isinstance(reg_type)
     reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
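The endianness mismatch that the new comment (and the NotImplementedError) refers to can be pictured with two 4-bit values; this is only an illustration with made-up values, not code from the change:

    import numpy as np
    a, b = np.uint8(0x1), np.uint8(0x2)   # two i4 values, logically [a, b]
    xla_byte = np.uint8((a << 4) | b)     # big-endian nibble packing: 0x12
    llvm_byte = np.uint8((b << 4) | a)    # little-endian nibble packing: 0x21
    assert xla_byte != llvm_byte          # a plain bitcast would reorder elements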
@@ -1641,30 +1646,57 @@ class FragmentedArray:
     if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)):
       raise ValueError("Memory tiling must be a multiple of the register tiling")

-    if swizzle not in {32, 64, 128}:
-      raise ValueError("Only swizzled transfers supported")
-    bw = mgpu.bitwidth(dtype)
-    swizzle_tile_elems = (16 * 8) // bw
-    swizzle_group_elems = (128 * 8) // bw
-    swizzle_groups_per_block = swizzle // 16
-    swizzle_block_elems = swizzle_groups_per_block * swizzle_group_elems
-
-    tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
+    elem_tiled_strides = list(tiling.tile_strides(tuple(ref_strides)))
     tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape)))
-    lane_strides = [tiled_strides[d] for d in layout.lane_dims]
+    elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims]
     lane_shape = [tiled_shape[d] for d in layout.lane_dims]
-    if tiled_strides[layout.vector_dim] != 1:
+    if elem_tiled_strides[layout.vector_dim] != 1:
       raise ValueError("Stride of the vectorized dimension should be 1")
     for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim):
       tiled_shape[d] = 1
     full_tiling = Tiling((ref_tiling_shape, *tiling.tiles))
     full_layout = dataclasses.replace(layout, tiling=full_tiling)

+    element_bits = mgpu.bitwidth(dtype)
+    if (layout.vector_length * element_bits) % 8 != 0:
+      raise ValueError(
+          f"Vector length ({layout.vector_length}) must be a multiple of bytes,"
+          f" but has {layout.vector_length * element_bits} bits"
+      )
+    transfer_bytes = (layout.vector_length * element_bits) // 8
+    # Not sure if this is strictly required for all data types, but it certainly
+    # is for sub-byte types (else we might not increment the pointer by whole bytes).
+    if any(
+        s % layout.vector_length and i != layout.vector_dim and d != 1
+        for i, (s, d) in enumerate_negative(
+            list(zip(elem_tiled_strides, tiled_shape))
+        )
+    ):
+      raise ValueError(
+          "Tiled strides must be a multiple of the vector length, except for the"
+          " vector dimension"
+      )
+
+    if swizzle not in {32, 64, 128}:
+      raise ValueError("Only swizzled transfers supported")
+    # We will be computing the offsets in units of vectors, not elements,
+    # to better support sub-byte types.
+    swizzle_tile_transfers = 16 // transfer_bytes
+    swizzle_group_transfers = 128 // transfer_bytes
+    swizzle_groups_per_block = swizzle // 16
+    swizzle_block_transfers = swizzle_groups_per_block * swizzle_group_transfers
+    # Technically we should keep the vector_dim set to 1, but its shape is 1
+    # so it does not matter.
+    transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides]
+    transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype)
+
     plan = plan_tiled_transfer(
-        tiled_shape, tiled_strides, lane_shape, lane_strides, layout, bw, swizzle
+        tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout,
+        element_bits, swizzle
     )

-    dyn_tiled_strides = [c(s) for s in tiled_strides]
+    # All offsets are in units of transfer_dtype.
+    dyn_tiled_strides = [c(s) for s in transfer_tiled_strides]
     lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides)
     warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides)
     dyn_offset = arith.addi(lane_offset, warp_offset)
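To make the new transfer-unit quantities concrete, the formulas above evaluate as follows for a hypothetical i4 transfer with vector_length 8 and swizzle 128 (illustrative numbers only, not from the commit):

    element_bits = 4
    vector_length = 8
    swizzle = 128
    transfer_bytes = vector_length * element_bits // 8                            # 4
    swizzle_tile_transfers = 16 // transfer_bytes                                 # 4
    swizzle_group_transfers = 128 // transfer_bytes                               # 32
    swizzle_groups_per_block = swizzle // 16                                      # 8
    swizzle_block_transfers = swizzle_groups_per_block * swizzle_group_transfers  # 256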
@@ -1673,10 +1705,10 @@ class FragmentedArray:
     ptr = utils.memref_ptr(ref, memory_space=3)
     _as_consts = lambda consts: [c(const) for const in consts.tolist()]
     # This has bits set only for the offset bits that influence swizzling.
-    swizzle_mask = swizzle_block_elems - swizzle_tile_elems
+    swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers
     for tile_idx in np.ndindex(*tiled_shape):
       indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms])
-      const_offset = np.dot(indices, tiled_strides)
+      const_offset = np.dot(indices, transfer_tiled_strides)
       # We split the offset into a part that interacts with swizzling and a
       # part that doesn't. This lets us generate better code because constant
       # offsets can be fused into load and store instructions.
@@ -1686,14 +1718,14 @@ class FragmentedArray:
           dyn_offset, plan.select(_as_consts(const_offset_swizzle))
       )
       swizzle_group = arith.remui(
-          arith.divui(offset_pre_swizzle, c(swizzle_group_elems)),
+          arith.divui(offset_pre_swizzle, c(swizzle_group_transfers)),
           c(swizzle_groups_per_block),
       )
-      swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems))
+      swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_transfers))
       offset = arith.xori(offset_pre_swizzle, swizzle_bits)
-      reg_ptr = utils.getelementptr(ptr, [offset], dtype)
+      reg_ptr = utils.getelementptr(ptr, [offset], transfer_dtype)
       offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle))
-      reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], dtype)
+      reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], transfer_dtype)
       reg_idxs = [
           tiling.tile_indices(full_tiling.untile_indices(idx))
           for idx in indices.tolist()
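Written out in plain Python, the swizzling sequence above (divui, remui, muli, xori) computes the following for each offset; this is a sketch of the same arithmetic, not the MLIR-emitting code itself:

    def swizzled_offset(offset_pre_swizzle: int, swizzle_group_transfers: int,
                        swizzle_groups_per_block: int, swizzle_tile_transfers: int) -> int:
      # Which swizzle group (128 bytes of data) the offset falls into, modulo the block.
      swizzle_group = (offset_pre_swizzle // swizzle_group_transfers) % swizzle_groups_per_block
      # XOR-ing these bits into the offset yields the swizzled address, in transfer units.
      swizzle_bits = swizzle_group * swizzle_tile_transfers
      return offset_pre_swizzle ^ swizzle_bits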
@@ -1789,13 +1821,18 @@ def plan_tiled_transfer(
     lane_shape: Sequence[int],
     lane_strides: Sequence[int],
     layout: TiledLayout,
-    bw: int,
+    element_bits: int,
     swizzle: int,
 ) -> TransferPlan:
   i32 = ir.IntegerType.get_signless(32)
   c = lambda x: arith.constant(i32, x)
-  swizzle_tile_elems = (16 * 8) // bw
-  swizzle_group_elems = (128 * 8) // bw
+  # TODO(apaszke): Rewrite this function in terms of transfer_bytes (that we get
+  # from the caller).
+  swizzle_tile_elems = (16 * 8) // element_bits
+  swizzle_group_elems = (128 * 8) // element_bits
+  # Should be checked at the call site.
+  assert layout.vector_length * element_bits % 8 == 0
+  transfer_bytes = (layout.vector_length * element_bits) // 8
   # Below, all calculations are in elements, not in bytes, since it should
   # generalize better to sub-byte types.
   # Here, we verify two conditions:
@@ -1821,16 +1858,13 @@ def plan_tiled_transfer(
   # we simply narrow each bank to the transfer width. The truth is more likely
   # that bank conflicts only don't occur if the addresses mapping to the same
   # bank are contiguous, but that's a more complicated check to perform.
-  if (layout.vector_length * bw) % 8 != 0:
-    raise ValueError(f"Vector must be whole bytes {layout.vector_length, bw}")
-  transfer_bytes = (layout.vector_length * bw) // 8
   if transfer_bytes > SMEM_BANK_BYTES * 4:
     raise NotImplementedError
-  if bw > SMEM_BANK_BYTES * 8:
+  if element_bits > SMEM_BANK_BYTES * 8:
     raise NotImplementedError
   smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes)
   num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes)
-  elems_per_bank = (smem_bank_bytes * 8) // bw
+  elems_per_bank = (smem_bank_bytes * 8) // element_bits
   num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1)
   wavefront_lanes = WARP_SIZE // num_wavefronts

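Plugging in the usual NVIDIA shared-memory constants (SMEM_BANK_BYTES = 4, SMEM_BANKS = 32, WARP_SIZE = 32; assumed here, they are not shown in the hunk) for a 4-byte transfer of i4 elements gives a feel for the bank math above:

    SMEM_BANK_BYTES, SMEM_BANKS, WARP_SIZE = 4, 32, 32             # assumed values
    element_bits, transfer_bytes = 4, 4                            # i4, 8-element vectors
    smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes)         # 4
    num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes)  # 32
    elems_per_bank = (smem_bank_bytes * 8) // element_bits         # 8 i4 elements per bank
    num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1)     # 1
    wavefront_lanes = WARP_SIZE // num_wavefronts                  # 32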
@@ -334,8 +334,9 @@ def bytewidth(ty: ir.Type):
   assert bw % 8 == 0, ty
   return bw // 8

-def bitwidth(ty: ir.Type):
-  # The actual width of TF32 is 19 bits. However, sinc we need to treat it as
+
+def bitwidth_impl(ty: ir.Type):
+  # The actual width of TF32 is 19 bits. However, we need to treat it as
   # 32 bits for compatibility reasons. TF32 used to be 32 bits wide in upstream
   # MLIR, but it changed in
   # https://github.com/llvm/llvm-project/commit/67a1fdb014790a38a205d28e1748634de34471dd.
@@ -350,6 +351,13 @@ def bitwidth(ty: ir.Type):
   raise NotImplementedError(ty)


+def bitwidth(ty: ir.Type):
+  result = bitwidth_impl(ty)
+  if result.bit_count() != 1:
+    raise ValueError(f"Only power of 2 bitwidths are supported, got: {result}")
+  return result
+
+
 @dataclasses.dataclass(frozen=True)
 class DynamicSlice:
   base: ir.Value | int
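The new bitwidth wrapper only layers a power-of-two guard over bitwidth_impl, using Python's int.bit_count; a small illustration (not part of the change):

    assert (4).bit_count() == 1     # i4 passes the guard
    assert (32).bit_count() == 1    # f32, and tf32 treated as 32 bits, pass
    assert (19).bit_count() != 1    # a 19-bit width would raise ValueError in bitwidth()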
@@ -561,27 +561,43 @@ class WGMMATest(TestCase):
       ("bf16_i8", jnp.bfloat16, jnp.int8),
       ("i8_bf16", jnp.int8, jnp.bfloat16),
       ("i8_i8", jnp.int8, jnp.int8),
+      ("i4_i4", jnp.int4, jnp.int4),
+      # TODO(apaszke): This needs specialized casts to handle the fact that XLA
+      # packs int4 in big-endian order into bytes, which is the opposite of
+      # what LLVM expects...
+      # ("i4_bf16", jnp.int4, jnp.bfloat16),
   )
   def test_convert_tiled(self, jax_dtype_from, jax_dtype_to):
     mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
     mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
     m = 128
-    n = 256 // bytewidth(mlir_dtype_from)
+    n = 256 * 8 // bitwidth(mlir_dtype_from)
     def kernel(ctx, inp, out, smem):
       del ctx
       smem_from, smem_to = smem
       copy(inp, smem_from, swizzle=128)
       t = mgpu.FragmentedArray.load_tiled(
-          smem_from, swizzle=128, is_signed=utils.is_signed(jax_dtype_from)
+          smem_from,
+          swizzle=128,
+          is_signed=utils.is_signed(jax_dtype_from),
+          layout=fa._tiled_wgmma_layout((m, n))
       )
       t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
       t.store_tiled(smem_to, swizzle=128)
       copy(smem_to, out, swizzle=128)

-    from_tiling = (64, 128 // bytewidth(mlir_dtype_from))
-    to_tiling = (64, 128 // bytewidth(mlir_dtype_to))
+    from_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_from))
+    to_tiling = (64, 128 * 8 // bitwidth(mlir_dtype_to))
+    # We only test lossless conversions for now.
+    # TODO(apaszke): Test and fix failures that appear with lossy conversions.
+    int_sample_dtype = getattr(
+        jnp,
+        "int" + str(min(bitwidth(mlir_dtype_from), bitwidth(mlir_dtype_to))),
+    )
+    sample_iinfo = jnp.iinfo(int_sample_dtype)
     expected_raw = self.prng.integers(
-        low=-127, high=127, size=(m, n), dtype=np.int8
+        low=sample_iinfo.min, high=sample_iinfo.max,
+        size=(m, n), dtype=np.int32
     )
     expected = lambda jax_dtype, tiling: expected_raw.reshape(
         m // tiling[0], tiling[0], n // tiling[1], tiling[1]
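For the newly added "i4_i4" case the size formulas above work out as follows (just the arithmetic, with bitwidth = 4):

    n = 256 * 8 // 4                    # 512
    from_tiling = (64, 128 * 8 // 4)    # (64, 256)
    to_tiling = (64, 128 * 8 // 4)      # (64, 256)
    # int_sample_dtype resolves to jnp.int4, so samples are drawn between
    # jnp.iinfo(jnp.int4).min and .max, i.e. from the range [-8, 7).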