[Mosaic GPU] Add a lowering for simple async_load and async_store ops.

Only untransformed and unsliced loads/stores are supported for now. The rest will be handled in a follow-up.
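For reference, a minimal sketch of what the new lowering accepts, mirroring the test_pointwise_kernel_with_tma test added below (a_gmem_ref, a_smem_ref, result_smem_ref, result_gmem_ref, tma_barrier, zero_i32, and shape are assumed to be set up by the surrounding kernel body):

    # GMEM -> SMEM copy; the load arrives on a TMA barrier.
    mgpu_dialect.async_load(
        source=a_gmem_ref,
        destination=a_smem_ref,
        barrier=tma_barrier.as_dialect_barrier(),
        indices=[zero_i32, zero_i32],
        slice_lengths=shape,  # Full-shape lengths, i.e. no slicing.
        transforms=ir.ArrayAttr.get([]),  # No transforms are supported yet.
        collective=ir.ArrayAttr.get([]),
        arrive=False,
    )
    # SMEM -> GMEM copy; stores take no barrier.
    mgpu_dialect.async_store(
        source=result_smem_ref,
        destination=result_gmem_ref,
        indices=[zero_i32, zero_i32],
        slice_lengths=shape,
        transforms=ir.ArrayAttr.get([]),
    )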

PiperOrigin-RevId: 708347442
Dimitar (Mitko) Asenov 2024-12-20 09:37:30 -08:00 committed by jax authors
parent 01e8f889c2
commit dad23fed09
5 changed files with 242 additions and 50 deletions

View File

@@ -174,7 +174,6 @@ def _construct_smem_reftree(
) -> RefTree:
index = ir.IndexType.get()
i8 = ir.IntegerType.get_signless(8)
-ptr = ir.Type.parse("!llvm.ptr")
smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@@ -183,11 +182,19 @@
for ref_ty in flat_ref_tys:
def get_barrier_ptr(num_barriers: int) -> ir.Value:
nonlocal dynamic_smem_offset
-smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3)
-barrier_base_ptr = llvm.getelementptr(
-ptr, smem_base_ptr, [], [dynamic_smem_offset], i8
-)
-dynamic_smem_offset += num_barriers * MBARRIER_BYTES
+workgroup_nvptx_address_space = (
+dialect_lowering.gpu_address_space_to_nvptx(
+gpu.AddressSpace.Workgroup
+)
+)
+smem_base_ptr = utils.memref_ptr(
+dynamic_smem, memory_space=workgroup_nvptx_address_space
+)
+smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
+barrier_base_ptr = llvm.getelementptr(
+smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8
+)
+dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES
return barrier_base_ptr
match ref_ty:
case Union(members):
@@ -227,9 +234,6 @@ def _construct_smem_reftree(
return jax.tree.unflatten(smem_buffer_tree, smem_refs)
-MBARRIER_BYTES = 8
def _smem_tree_size(smem_buffers: ShapeTree) -> int:
leaves = jax.tree.leaves(
smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@@ -244,9 +248,9 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
| ClusterBarrier(_, num_barriers=num_barriers)
| Barrier(_, num_barriers=num_barriers)
):
-if size % MBARRIER_BYTES:
+if size % utils.MBARRIER_BYTES:
raise NotImplementedError("Misaligned barrier allocation")
-size += num_barriers * MBARRIER_BYTES
+size += num_barriers * utils.MBARRIER_BYTES
case _:
size += _count_buffer_bytes(l)
return size
@@ -379,9 +383,11 @@ def _lower_as_gpu_kernel(
attrs["sym_name"] = ir.StringAttr.get(module_name)
if kernel_name is None:
kernel_name = getattr(body, "__name__", "anonymous")
+# These are needed as nonlocal below.
+launch_ctx, scratch_arr = None, None
with ir.InsertionPoint(module.body):
_declare_runtime_functions()
-gmem_scratch_bytes = 0
global_scratch = llvm.GlobalOp(
ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet.
"global_scratch",
@@ -390,7 +396,7 @@ def _lower_as_gpu_kernel(
)
@func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}")
def main(token_ptr, buffers):
-nonlocal gmem_scratch_bytes
+nonlocal launch_ctx, scratch_arr
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
arg_refs = []
for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
@@ -408,27 +414,40 @@ def _lower_as_gpu_kernel(
with _launch(
token, grid, cluster, block, scratch_arr, smem_scratch_shape,
prof_spec, prof_buffer
-) as (launch_ctx, smem_refs):
+) as (_launch_ctx, smem_refs):
+nonlocal launch_ctx
+launch_ctx = _launch_ctx
body(launch_ctx, *in_refs, *out_refs, smem_refs)
-gmem_scratch_bytes = launch_ctx.next_scratch_offset
-# Allocate and initialize the host buffer right before the launch.
-# Note that we couldn't do that before, because we had to run the body
-# to learn what the scratch contains.
-with ir.InsertionPoint(scratch_arr.owner):
-scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
-scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty)
-scratch_arr.set_type(scratch_arr_ty)
-for init_callback in launch_ctx.host_scratch_init:
-init_callback(scratch_alloc.result)
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
sym_tab = ir.SymbolTable(module.operation)
sym_tab.insert(main.func_op)
sym_tab.insert(global_scratch)
module.operation.verify()
-return module, out_shape, unwrap_output_tuple
+return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr
+def _initialize_scratch(
+launch_ctx: launch_context.LaunchContext,
+scratch_arr: ir.Value,
+):
+"""
+Allocates and initializes the host buffer right before the launch. This needs
+to be done after all TMA descriptors have been recorded by the launch context.
+Only then do we know what the scratch contains.
+When using the Mosaic GPU dialect, the necessary information is known only
+after the lowering passes have run.
+"""
+with ir.InsertionPoint(scratch_arr.owner):
+gmem_scratch_bytes = launch_ctx.next_scratch_offset
+scratch_alloc_op = scratch_arr.owner.opview.addr.owner.opview
+scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
+scratch_alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty)
+scratch_arr.set_type(scratch_arr_ty)
+for init_callback in launch_ctx.host_scratch_init:
+init_callback(scratch_alloc_op.result)
def _declare_runtime_functions():
"""Declares the runtime functions that can be used by the generated code."""
ptr_ty = ir.Type.parse("!llvm.ptr")
@@ -462,7 +481,7 @@ def as_gpu_kernel(
elif not isinstance(in_shape, tuple):
in_shape = (in_shape,)
-module, out_shape, unwrap_output_tuple = (
+module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
_lower_as_gpu_kernel(
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
module_name, kernel_name, prof_spec
@@ -473,7 +492,10 @@ def as_gpu_kernel(
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
-dialect_lowering.lower_mgpu_dialect(module) # pytype: disable=attribute-error
+dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
+_initialize_scratch(launch_ctx, scratch_arr)
+module.operation.verify()
expected_arg_treedef = jax.tree.structure(in_shape)
def _check_args(*args):
@@ -530,6 +552,7 @@ def as_torch_gpu_kernel(
cluster: tuple[int, int, int] = (1, 1, 1),
module_name: str = "unknown",
kernel_name: str | None = None,
+thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
):
try:
import torch
@@ -545,13 +568,22 @@ def as_torch_gpu_kernel(
flat_out_types, out_treedef = jax.tree.flatten(out_shape)
expected_arg_treedef = jax.tree.structure(in_shape)
-module, out_shape, unwrap_output_tuple = (
+module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
_lower_as_gpu_kernel(
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
module_name, kernel_name, prof_spec
)
)
+if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
+# Run Python lowering passes. The remaining passes will be run in C++ in
+# jax/jaxlib/mosaic/gpu/custom_call.cc
+layout_inference.infer_layout(module) # pytype: disable=attribute-error
+dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error
+_initialize_scratch(launch_ctx, scratch_arr)
+module.operation.verify()
# Get our hands on the compilation and unload functions
try:
import jax_plugins.xla_cuda12 as cuda_plugin

View File

@@ -31,13 +31,14 @@ from jax._src.lib.mlir.dialects import vector
import numpy as np
from .fragmented_array import FragmentedArray, WGStridedFragLayout
from .launch_context import LaunchContext
from .layouts import from_strided_fragmented_layout_attr, has_any_layout_set, is_strided_fragmented_layout, should_have_layout, to_strided_fragmented_layout_attr
-from .utils import c, ptr_as_memref, single_thread_predicate
+from .utils import BarrierRef, c, memref_ptr, ptr_as_memref, single_thread_predicate
# mypy: ignore-errors
-MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]
+MlirLoweringRule = Callable[[LaunchContext, ir.Operation | ir.OpView], Sequence[ir.Value]]
_lowerings: dict[str, MlirLoweringRule] = {}
@@ -88,6 +89,9 @@ def _fragmented_array_from_ir(
[operand.type for operand in conversion_cast.operands],
conversion_cast.results,
)
+if not isinstance(converted_outputs, list):
+converted_outputs = [converted_outputs]
reverse_conversion_cast = converted_outputs[0].owner.opview
for attribute in conversion_cast.attributes:
@@ -138,6 +142,7 @@ def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int:
@_register_lowering(InitializeBarrierOp)
def _initialize_barrier_op_lowering_rule(
+_: LaunchContext,
initialize_barrier_op: InitializeBarrierOp,
) -> Sequence[ir.Value]:
@@ -170,7 +175,7 @@ def _initialize_barrier_op_lowering_rule(
@_register_lowering(vector.LoadOp)
def _vector_load_op_lowering_rule(
-vector_load_op: vector.LoadOp,
+_: LaunchContext, vector_load_op: vector.LoadOp
) -> Sequence[ir.Value]:
(out_layout_attr,) = cast(
ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
@@ -199,7 +204,7 @@ def _vector_load_op_lowering_rule(
@_register_lowering(vector.StoreOp)
def _vector_store_op_lowering_rule(
-vector_store_op: vector.StoreOp,
+_: LaunchContext, vector_store_op: vector.StoreOp
) -> Sequence[ir.Value]:
in_layout_attr, *_ = cast(
@@ -229,8 +234,44 @@ def _vector_store_op_lowering_rule(
return []
+@_register_lowering(mgpu.AsyncLoadOp)
+def _mgpu_async_load_op_lowering_rule(
+launch_context: LaunchContext, load_op: mgpu.AsyncLoadOp
+) -> Sequence[ir.Value]:
+mem_space = gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup)
+# TODO(dasenov): Add support for the remaining op properties.
+launch_context.async_copy(
+src_ref=load_op.source,
+dst_ref=load_op.destination,
+barrier=BarrierRef(
+base_address=memref_ptr(load_op.barrier, memory_space=mem_space),
+offset=c(0, ir.IntegerType.get_signless(64)),
+phases=None,
+num_barriers=1,
+),
+arrive=load_op.arrive,
+uniform=False,
+)
+return []
+@_register_lowering(mgpu.AsyncStoreOp)
+def _mgpu_async_store_op_lowering_rule(
+launch_context: LaunchContext, store_op: mgpu.AsyncStoreOp
+) -> Sequence[ir.Value]:
+# TODO(dasenov): Add support for the remaining op properties.
+launch_context.async_copy(
+src_ref=store_op.source,
+dst_ref=store_op.destination,
+)
+return []
@_register_lowering(arith.AddFOp)
-def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
+def _arith_addf_op_lowering_rule(
+_: LaunchContext, add: arith.AddFOp
+) -> Sequence[ir.Value]:
fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
fragmented_array_rhs = _fragmented_array_from_ir(add.rhs)
@@ -242,7 +283,7 @@ def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
]
-def lower_mgpu_dialect(module: ir.Module):
+def lower_mgpu_dialect(module: ir.Module, launch_context: LaunchContext):
module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
module.context.load_all_available_dialects()
@@ -257,7 +298,7 @@ def lower_mgpu_dialect(module: ir.Module):
if should_have_layout(op) and not has_any_layout_set(op):
raise ValueError(f"{op} is missing a layout and can not be lowered.")
-new_results = lowering_rule(op)
+new_results = lowering_rule(launch_context, op)
for old, new in zip(op.results, new_results):
old.replace_all_uses_with(new)

View File

@@ -36,24 +36,28 @@ from jaxlib.mlir.dialects import scf
from jaxlib.mlir.dialects import vector
import numpy as np
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401
# mypy: ignore-errors
WARPGROUP_SIZE: int = 128
DYNAMIC = -9223372036854775808
DYNAMIC32 = -2147483648
+MBARRIER_BYTES = 8
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
-def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
+def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None):
i64 = ir.IntegerType.get_signless(64)
rank = len(memref_ty.shape)
ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>"
if rank > 0:
desc_ty = ir.Type.parse(
f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>, array<{rank} x i64>)>"
)
else:
desc_ty = ir.Type.parse("!llvm.struct<(ptr, ptr, i64)>")
desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
desc = llvm.UndefOp(desc_ty)
desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation
desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base
@@ -321,6 +325,8 @@ def bytewidth(ty: ir.Type):
return ir.IntegerType(ty).width // 8
if ir.FloatType.isinstance(ty):
return ir.FloatType(ty).width // 8
+if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
+return MBARRIER_BYTES
raise NotImplementedError(ty)
@@ -743,6 +749,18 @@ class BarrierRef:
ptr, self.base_address, [self.offset], [DYNAMIC32], i64
)
+def as_dialect_barrier(self) -> ir.Value:
+if self.num_barriers > 1:
+raise NotImplementedError(
+f"Only BarrierRef with num_barriers=1 is supported in the MLIR "
+f"Mosaic GPU dialect, but got num_barriers={self.num_barriers}"
+)
+return ptr_as_memref(
+self.base_address,
+ir.MemRefType.get((), ir.Type.parse("!mosaic_gpu.barrier")),
+ptr_memory_space=3,
+)
@dataclasses.dataclass(frozen=True)
class CollectiveBarrierRef:
@@ -997,19 +1015,21 @@ def warp_tree_reduce(value, op, group_size):
def memref_ptr(memref_arg, memory_space=None):
i64 = ir.IntegerType.get_signless(64)
memref_ty = ir.MemRefType(memref_arg.type)
-if len(memref_ty.shape) == 0:
-raise NotImplementedError
-elem_bytewidth = bytewidth(memref_ty.element_type)
rank = len(memref_ty.shape)
# TODO: Read out memory space from memref
space = "" if memory_space is None else "<" + str(memory_space) + ">"
ptr_ty = ir.Type.parse("!llvm.ptr" + space)
-desc_ty = ir.Type.parse(
-f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
-f" array<{rank} x i64>)>"
-)
+if rank == 0:
+desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
+else:
+desc_ty = ir.Type.parse(
+f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
+f" array<{rank} x i64>)>"
+)
desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
+elem_bytewidth = bytewidth(memref_ty.element_type)
offset_elems = llvm.extractvalue(i64, desc, [2])
offset_bytes = llvm.mul(
offset_elems,

View File

@@ -586,7 +586,7 @@ class DialectLoweringTest(MosaicGpuTest):
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1,
)
-mgpu.lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module, None)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, self.module.body.operations))
@@ -604,7 +604,7 @@ class DialectLoweringTest(MosaicGpuTest):
arrival_count=1,
)
scf.yield_([])
-mgpu.lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module, None)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
@@ -626,7 +626,7 @@ class DialectLoweringTest(MosaicGpuTest):
memref.copy(barriers_ref, barriers_ref)
self.assertTrue(self.module.operation.verify())
-mgpu.lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module, None)
self.assertTrue(self.module.operation.verify())
all_mbarrier_init_shared_ops = find_if(
@@ -654,7 +654,7 @@ class DialectLoweringTest(MosaicGpuTest):
with self.assertRaisesRegex(
ValueError, "missing a layout and can not be lowered"
):
-mgpu.lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module, None)
def test_lowering_eliminates_layouts(self):
shape = (4, 128)
@@ -670,7 +670,7 @@ class DialectLoweringTest(MosaicGpuTest):
)
])
-mgpu.lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module, None)
all_ops_with_layouts = find_if(
self.module,
@@ -691,7 +691,7 @@ class DialectLoweringTest(MosaicGpuTest):
vector.store(array, ref, [zero_index, zero_index])
mgpu.infer_layout(self.module)
-mgpu.lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module, None)
all_loads = find_if(
self.module,

View File

@@ -47,6 +47,8 @@ except ImportError:
z = 2
else:
import jax.experimental.mosaic.gpu as mgpu
+from jax.experimental.mosaic.gpu import core
+from jax.experimental.mosaic.gpu import launch_context
from jax.experimental.mosaic.gpu import utils as utils
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu.utils import * # noqa: F403
@@ -1937,6 +1939,103 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
+def test_pointwise_kernel_with_tma(self):
+def add(
+ctx: launch_context.LaunchContext,
+a_gmem_ref: ir.Value,
+b_gmem_ref: ir.Value,
+result_gmem_ref: ir.Value,
+smem: list[ir.Value],
+):
+del ctx
+a_smem_ref, b_smem_ref, result_smem_ref = smem[:3]
+tma_barrier = smem[3]
+memref_type = ir.MemRefType(a_gmem_ref.type)
+shape = memref_type.shape
+elt_type = memref_type.element_type
+zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
+with utils.single_thread():
+memref_bytes = utils.bytewidth(elt_type) # Also correct if rank == 0
+for size in shape:
+memref_bytes *= size
+nvvm.mbarrier_arrive_expect_tx_shared(
+tma_barrier.get_ptr(),
+arith.constant(ir.IntegerType.get_signless(32), 2*memref_bytes),
+)
+# GMEM -> SMEM
+mgpu_dialect.async_load(
+source=a_gmem_ref,
+destination=a_smem_ref,
+barrier=tma_barrier.as_dialect_barrier(),
+indices=[zero_i32, zero_i32],
+slice_lengths=shape,
+transforms=ir.ArrayAttr.get([]),
+collective=ir.ArrayAttr.get([]),
+arrive=False,
+)
+mgpu_dialect.async_load(
+source=b_gmem_ref,
+destination=b_smem_ref,
+barrier=tma_barrier.as_dialect_barrier(),
+indices=[zero_i32, zero_i32],
+slice_lengths=shape,
+transforms=ir.ArrayAttr.get([]),
+collective=ir.ArrayAttr.get([]),
+arrive=False,
+)
+tma_barrier.wait()
+zero_index = arith.constant(ir.IndexType.get(), 0)
+# SMEM -> registers
+ab_type = ir.VectorType.get(shape, elt_type)
+a = vector.load(ab_type, a_smem_ref, [zero_index, zero_index])
+b = vector.load(ab_type, b_smem_ref, [zero_index, zero_index])
+# Computation
+add = arith.addf(arith.addf(a, b), b)
+# Registers -> SMEM
+vector.store(add, result_smem_ref, [zero_index, zero_index])
+# SMEM -> GMEM
+mgpu_dialect.async_store(
+source=result_smem_ref,
+destination=result_gmem_ref,
+indices=[zero_i32, zero_i32],
+slice_lengths=shape,
+transforms=ir.ArrayAttr.get([]),
+)
+nvvm.cp_async_bulk_wait_group(0)
+utils.warpgroup_barrier()
+dtype = jnp.bfloat16
+shape = (128, 128)
+jax_shape = jax.ShapeDtypeStruct(shape, dtype)
+kernel = mgpu.as_gpu_kernel(
+add,
+grid=(1, 1, 1),
+block=(128, 1, 1),
+in_shape=(jax_shape, jax_shape),
+out_shape=jax_shape,
+smem_scratch_shape=[
+jax_shape,
+jax_shape,
+jax_shape,
+core.TMABarrier(1),
+],
+thread_semantics=mgpu.ThreadSemantics.Warpgroup,
+)
+x = self.prng.uniform(-1, 1, shape).astype(dtype)
+y = self.prng.uniform(-1, 1, shape).astype(dtype)
+self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
class UtilsTest(TestCase):
@parameterized.parameters(