[Mosaic GPU] Add a lowering for simple async_load and async_store ops.

Only untransformed and unsliced loads/stores are supported for now. The rest
will be a follow-up.

PiperOrigin-RevId: 708347442
parent 01e8f889c2
commit dad23fed09
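For orientation, the sketch below condenses the test_pointwise_kernel_with_tma test added at the bottom of this diff into a single GMEM -> SMEM -> GMEM round trip through the new dialect ops. It is a sketch only: the import paths and aliases are assumptions that mirror the modules touched in the hunks below, the refs are assumed to be 2-D, and the kernel is assumed to have been built with an SMEM scratch of [ref_shape, core.TMABarrier(1)].

    # Sketch only: condensed from the test added at the bottom of this diff.
    from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
    from jax._src.lib.mlir import ir
    from jax._src.lib.mlir.dialects import arith, nvvm
    from jax.experimental.mosaic.gpu import utils


    def copy_through_smem(ctx, in_gmem, out_gmem, smem):
      """Round-trips a 2-D ref GMEM -> SMEM -> GMEM with the new dialect ops."""
      del ctx  # The kernel body does not need the launch context directly.
      # Assumes smem_scratch_shape=[ref_shape, core.TMABarrier(1)].
      in_smem, tma_barrier = smem
      memref_ty = ir.MemRefType(in_gmem.type)
      shape = memref_ty.shape
      zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)

      with utils.single_thread():
        # The TMA barrier must expect exactly the number of bytes that the
        # load below will deliver.
        nbytes = utils.bytewidth(memref_ty.element_type)
        for size in shape:
          nbytes *= size
        nvvm.mbarrier_arrive_expect_tx_shared(
            tma_barrier.get_ptr(),
            arith.constant(ir.IntegerType.get_signless(32), nbytes),
        )

      # GMEM -> SMEM: untransformed and unsliced, i.e. zero indices and
      # slice_lengths covering the whole ref (the only case lowered so far).
      mgpu_dialect.async_load(
          source=in_gmem,
          destination=in_smem,
          barrier=tma_barrier.as_dialect_barrier(),
          indices=[zero_i32, zero_i32],
          slice_lengths=shape,
          transforms=ir.ArrayAttr.get([]),
          collective=ir.ArrayAttr.get([]),
          arrive=False,
      )
      tma_barrier.wait()

      # SMEM -> GMEM, then wait for the bulk copy to complete.
      mgpu_dialect.async_store(
          source=in_smem,
          destination=out_gmem,
          indices=[zero_i32, zero_i32],
          slice_lengths=shape,
          transforms=ir.ArrayAttr.get([]),
      )
      nvvm.cp_async_bulk_wait_group(0)
      utils.warpgroup_barrier()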
@@ -174,7 +174,6 @@ def _construct_smem_reftree(
 ) -> RefTree:
   index = ir.IndexType.get()
   i8 = ir.IntegerType.get_signless(8)
-  ptr = ir.Type.parse("!llvm.ptr")
   smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
   flat_ref_tys, smem_buffer_tree = jax.tree.flatten(
       smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@@ -183,11 +182,19 @@ def _construct_smem_reftree(
   for ref_ty in flat_ref_tys:
     def get_barrier_ptr(num_barriers: int) -> ir.Value:
       nonlocal dynamic_smem_offset
-      smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3)
-      barrier_base_ptr = llvm.getelementptr(
-          ptr, smem_base_ptr, [], [dynamic_smem_offset], i8
-      )
-      dynamic_smem_offset += num_barriers * MBARRIER_BYTES
+      workgroup_nvptx_address_space = (
+          dialect_lowering.gpu_address_space_to_nvptx(
+              gpu.AddressSpace.Workgroup
+          )
+      )
+      smem_base_ptr = utils.memref_ptr(
+          dynamic_smem, memory_space=workgroup_nvptx_address_space
+      )
+      smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
+      barrier_base_ptr = llvm.getelementptr(
+          smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8
+      )
+      dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES
       return barrier_base_ptr
     match ref_ty:
       case Union(members):
@@ -227,9 +234,6 @@ def _construct_smem_reftree(
   return jax.tree.unflatten(smem_buffer_tree, smem_refs)
 
 
-MBARRIER_BYTES = 8
-
-
 def _smem_tree_size(smem_buffers: ShapeTree) -> int:
   leaves = jax.tree.leaves(
       smem_buffers, is_leaf=lambda x: isinstance(x, Union)
@@ -244,9 +248,9 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
           | ClusterBarrier(_, num_barriers=num_barriers)
           | Barrier(_, num_barriers=num_barriers)
       ):
-        if size % MBARRIER_BYTES:
+        if size % utils.MBARRIER_BYTES:
          raise NotImplementedError("Misaligned barrier allocation")
-        size += num_barriers * MBARRIER_BYTES
+        size += num_barriers * utils.MBARRIER_BYTES
       case _:
         size += _count_buffer_bytes(l)
   return size
@@ -379,9 +383,11 @@ def _lower_as_gpu_kernel(
   attrs["sym_name"] = ir.StringAttr.get(module_name)
   if kernel_name is None:
     kernel_name = getattr(body, "__name__", "anonymous")
 
+  # These are needed as nonlocal below.
+  launch_ctx, scratch_arr = None, None
   with ir.InsertionPoint(module.body):
     _declare_runtime_functions()
     gmem_scratch_bytes = 0
     global_scratch = llvm.GlobalOp(
         ir.Type.parse("!llvm.array<0 x i8>"),  # We don't know the shape yet.
         "global_scratch",
@@ -390,7 +396,7 @@ def _lower_as_gpu_kernel(
     )
     @func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}")
     def main(token_ptr, buffers):
-      nonlocal gmem_scratch_bytes
+      nonlocal launch_ctx, scratch_arr
       token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
       arg_refs = []
       for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
@@ -408,27 +414,40 @@ def _lower_as_gpu_kernel(
       with _launch(
           token, grid, cluster, block, scratch_arr, smem_scratch_shape,
           prof_spec, prof_buffer
-      ) as (launch_ctx, smem_refs):
+      ) as (_launch_ctx, smem_refs):
+        nonlocal launch_ctx
+        launch_ctx = _launch_ctx
         body(launch_ctx, *in_refs, *out_refs, smem_refs)
-        gmem_scratch_bytes = launch_ctx.next_scratch_offset
-        # Allocate and initialize the host buffer right before the launch.
-        # Note that we couldn't do that before, because we had to run the body
-        # to learn what the scratch contains.
-        with ir.InsertionPoint(scratch_arr.owner):
-          scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
-          scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty)
-          scratch_arr.set_type(scratch_arr_ty)
-          for init_callback in launch_ctx.host_scratch_init:
-            init_callback(scratch_alloc.result)
     main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
     sym_tab = ir.SymbolTable(module.operation)
     sym_tab.insert(main.func_op)
     sym_tab.insert(global_scratch)
     module.operation.verify()
 
-  return module, out_shape, unwrap_output_tuple
+  return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr
 
 
+def _initialize_scratch(
+    launch_ctx : launch_context.LaunchContext,
+    scratch_arr: ir.Value,
+):
+  """
+  Allocates and initializes the host buffer right before the launch. This needs
+  to be done after all TMA descriptors have been recorded by the launch context.
+  Only then we know what the scratch contains.
+
+  When using the Mosaic GPU dialect, the necessary information is known only
+  after the lowering passes have run.
+  """
+  with ir.InsertionPoint(scratch_arr.owner):
+    gmem_scratch_bytes = launch_ctx.next_scratch_offset
+    scratch_alloc_op = scratch_arr.owner.opview.addr.owner.opview
+    scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>")
+    scratch_alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty)
+    scratch_arr.set_type(scratch_arr_ty)
+    for init_callback in launch_ctx.host_scratch_init:
+      init_callback(scratch_alloc_op.result)
+
 
 def _declare_runtime_functions():
   """Declares the runtime functions that can be used by the generated code."""
   ptr_ty = ir.Type.parse("!llvm.ptr")
@@ -462,7 +481,7 @@ def as_gpu_kernel(
   elif not isinstance(in_shape, tuple):
     in_shape = (in_shape,)
 
-  module, out_shape, unwrap_output_tuple = (
+  module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
       _lower_as_gpu_kernel(
           body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
           module_name, kernel_name, prof_spec
@@ -473,7 +492,10 @@ def as_gpu_kernel(
     # Run Python lowering passes. The remaining passes will be run in C++ in
     # jax/jaxlib/mosaic/gpu/custom_call.cc
     layout_inference.infer_layout(module)  # pytype: disable=attribute-error
-    dialect_lowering.lower_mgpu_dialect(module)  # pytype: disable=attribute-error
+    dialect_lowering.lower_mgpu_dialect(module, launch_ctx)  # pytype: disable=attribute-error
+
+  _initialize_scratch(launch_ctx, scratch_arr)
+  module.operation.verify()
 
   expected_arg_treedef = jax.tree.structure(in_shape)
   def _check_args(*args):
@@ -530,6 +552,7 @@ def as_torch_gpu_kernel(
     cluster: tuple[int, int, int] = (1, 1, 1),
     module_name: str = "unknown",
     kernel_name: str | None = None,
+    thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
 ):
   try:
     import torch
@@ -545,13 +568,22 @@ def as_torch_gpu_kernel(
   flat_out_types, out_treedef = jax.tree.flatten(out_shape)
   expected_arg_treedef = jax.tree.structure(in_shape)
 
-  module, out_shape, unwrap_output_tuple = (
+  module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
      _lower_as_gpu_kernel(
          body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
          module_name, kernel_name, prof_spec
      )
  )
 
+  if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
+    # Run Python lowering passes. The remaining passes will be run in C++ in
+    # jax/jaxlib/mosaic/gpu/custom_call.cc
+    layout_inference.infer_layout(module)  # pytype: disable=attribute-error
+    dialect_lowering.lower_mgpu_dialect(module, launch_ctx)  # pytype: disable=attribute-error
+
+  _initialize_scratch(launch_ctx, scratch_arr)
+  module.operation.verify()
+
   # Get our hands on the compilation and unload functions
   try:
     import jax_plugins.xla_cuda12 as cuda_plugin
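Taken together, the hunks above move GMEM scratch finalization out of the kernel body: _lower_as_gpu_kernel now also hands back the launch context and the scratch array, and callers finalize the scratch only after every pass that may record TMA descriptors has run. The snippet below is not standalone code, only a condensed restatement of the call sequence in as_gpu_kernel using the names from the hunks above:

    # Condensed call sequence after this change (details elided).
    module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = (
        _lower_as_gpu_kernel(
            body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
            module_name, kernel_name, prof_spec
        )
    )

    if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
      # The dialect lowering may record further TMA descriptors on launch_ctx.
      layout_inference.infer_layout(module)
      dialect_lowering.lower_mgpu_dialect(module, launch_ctx)

    # Only now is the final GMEM scratch size known, so the scratch global can
    # be resized and its host-side initialization emitted.
    _initialize_scratch(launch_ctx, scratch_arr)
    module.operation.verify()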
@@ -31,13 +31,14 @@ from jax._src.lib.mlir.dialects import vector
 import numpy as np
 
 from .fragmented_array import FragmentedArray, WGStridedFragLayout
+from .launch_context import LaunchContext
 from .layouts import from_strided_fragmented_layout_attr, has_any_layout_set, is_strided_fragmented_layout, should_have_layout, to_strided_fragmented_layout_attr
-from .utils import c, ptr_as_memref, single_thread_predicate
+from .utils import BarrierRef, c, memref_ptr, ptr_as_memref, single_thread_predicate
 
 # mypy: ignore-errors
 
 
-MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]
+MlirLoweringRule = Callable[[LaunchContext, ir.Operation | ir.OpView], Sequence[ir.Value]]
 
 
 _lowerings: dict[str, MlirLoweringRule] = {}
@@ -88,6 +89,9 @@ def _fragmented_array_from_ir(
       [operand.type for operand in conversion_cast.operands],
       conversion_cast.results,
   )
+  if not isinstance(converted_outputs, list):
+    converted_outputs = [converted_outputs]
+
 
   reverse_conversion_cast = converted_outputs[0].owner.opview
   for attribute in conversion_cast.attributes:
@@ -138,6 +142,7 @@ def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int:
 
 @_register_lowering(InitializeBarrierOp)
 def _initialize_barrier_op_lowering_rule(
+    _: LaunchContext,
     initialize_barrier_op: InitializeBarrierOp,
 ) -> Sequence[ir.Value]:
 
@@ -170,7 +175,7 @@ def _initialize_barrier_op_lowering_rule(
 
 @_register_lowering(vector.LoadOp)
 def _vector_load_op_lowering_rule(
-    vector_load_op: vector.LoadOp,
+    _: LaunchContext, vector_load_op: vector.LoadOp
 ) -> Sequence[ir.Value]:
   (out_layout_attr,) = cast(
       ir.ArrayAttr, vector_load_op.attributes["out_layouts"]
@@ -199,7 +204,7 @@ def _vector_load_op_lowering_rule(
 
 @_register_lowering(vector.StoreOp)
 def _vector_store_op_lowering_rule(
-    vector_store_op: vector.StoreOp,
+    _: LaunchContext, vector_store_op: vector.StoreOp
 ) -> Sequence[ir.Value]:
 
   in_layout_attr, *_ = cast(
@@ -229,8 +234,44 @@ def _vector_store_op_lowering_rule(
   return []
 
 
+@_register_lowering(mgpu.AsyncLoadOp)
+def _mgpu_async_load_op_lowering_rule(
+    launch_context: LaunchContext, load_op: mgpu.AsyncLoadOp
+) -> Sequence[ir.Value]:
+  mem_space = gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup)
+
+  # TODO(dasenov): Add support for the remaining op properties.
+  launch_context.async_copy(
+      src_ref=load_op.source,
+      dst_ref=load_op.destination,
+      barrier=BarrierRef(
+          base_address=memref_ptr(load_op.barrier, memory_space=mem_space),
+          offset=c(0, ir.IntegerType.get_signless(64)),
+          phases=None,
+          num_barriers=1,
+      ),
+      arrive=load_op.arrive,
+      uniform=False,
+  )
+  return []
+
+
+@_register_lowering(mgpu.AsyncStoreOp)
+def _mgpu_async_store_op_lowering_rule(
+    launch_context: LaunchContext, store_op: mgpu.AsyncStoreOp
+) -> Sequence[ir.Value]:
+  # TODO(dasenov): Add support for the remaining op properties.
+  launch_context.async_copy(
+      src_ref=store_op.source,
+      dst_ref=store_op.destination,
+  )
+  return []
+
+
 @_register_lowering(arith.AddFOp)
-def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
+def _arith_addf_op_lowering_rule(
+    _: LaunchContext, add: arith.AddFOp
+) -> Sequence[ir.Value]:
 
   fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
   fragmented_array_rhs = _fragmented_array_from_ir(add.rhs)
@@ -242,7 +283,7 @@ def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
   ]
 
 
-def lower_mgpu_dialect(module: ir.Module):
+def lower_mgpu_dialect(module: ir.Module, launch_context: LaunchContext):
   module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
   module.context.load_all_available_dialects()
 
@@ -257,7 +298,7 @@ def lower_mgpu_dialect(module: ir.Module):
     if should_have_layout(op) and not has_any_layout_set(op):
       raise ValueError(f"{op} is missing a layout and can not be lowered.")
 
-    new_results = lowering_rule(op)
+    new_results = lowering_rule(launch_context, op)
 
     for old, new in zip(op.results, new_results):
       old.replace_all_uses_with(new)
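With the new MlirLoweringRule type above, every registered rule receives the LaunchContext as its first argument, whether or not it needs it; only the async_load/async_store rules actually use it, via launch_context.async_copy. A hypothetical extra rule would therefore look like the sketch below. It assumes the surrounding module's imports and helpers; the op choice and the _fragmented_array_to_ir call mirror the addf rule and are illustrative only, not part of this change:

    # Hypothetical rule, shown only to illustrate the (LaunchContext, op)
    # calling convention that lower_mgpu_dialect now applies to every rule.
    @_register_lowering(arith.MulFOp)
    def _arith_mulf_op_lowering_rule(
        _: LaunchContext, mul: arith.MulFOp
    ) -> Sequence[ir.Value]:
      # The launch context is unused: element-wise rules only rewrite values.
      lhs = _fragmented_array_from_ir(mul.lhs)
      rhs = _fragmented_array_from_ir(mul.rhs)
      return [_fragmented_array_to_ir(lhs * rhs, mul.result.type)]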
@@ -36,24 +36,28 @@ from jaxlib.mlir.dialects import scf
 from jaxlib.mlir.dialects import vector
 import numpy as np
 
+from jax._src.lib import mosaic_gpu_dialect as dialect  # noqa: F401
+
 # mypy: ignore-errors
 
 WARPGROUP_SIZE: int = 128
 DYNAMIC = -9223372036854775808
 DYNAMIC32 = -2147483648
+MBARRIER_BYTES = 8
 
 # pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
 
 
-def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
+def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None):
   i64 = ir.IntegerType.get_signless(64)
   rank = len(memref_ty.shape)
+  ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>"
   if rank > 0:
     desc_ty = ir.Type.parse(
-        f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
+        f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>, array<{rank} x i64>)>"
     )
   else:
-    desc_ty = ir.Type.parse("!llvm.struct<(ptr, ptr, i64)>")
+    desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
   desc = llvm.UndefOp(desc_ty)
   desc = llvm.InsertValueOp(desc, ptr, [0])  # Allocation
   desc = llvm.InsertValueOp(desc, ptr, [1])  # Aligned Base
@@ -321,6 +325,8 @@ def bytewidth(ty: ir.Type):
     return ir.IntegerType(ty).width // 8
   if ir.FloatType.isinstance(ty):
     return ir.FloatType(ty).width // 8
+  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
+    return MBARRIER_BYTES
   raise NotImplementedError(ty)
 
 
@@ -743,6 +749,18 @@ class BarrierRef:
         ptr, self.base_address, [self.offset], [DYNAMIC32], i64
     )
 
+  def as_dialect_barrier(self) -> ir.Value:
+    if self.num_barriers > 1:
+      raise NotImplementedError(
+          f"Only BarrierRef with num_barriers=1 is supported in the MLIR "
+          f"Mosaic GPU dialect, but got num_barriers={self.num_barriers}"
+      )
+    return ptr_as_memref(
+        self.base_address,
+        ir.MemRefType.get((), ir.Type.parse("!mosaic_gpu.barrier")),
+        ptr_memory_space=3,
+    )
+
 
 @dataclasses.dataclass(frozen=True)
 class CollectiveBarrierRef:
@@ -997,19 +1015,21 @@ def warp_tree_reduce(value, op, group_size):
 def memref_ptr(memref_arg, memory_space=None):
   i64 = ir.IntegerType.get_signless(64)
   memref_ty = ir.MemRefType(memref_arg.type)
-  if len(memref_ty.shape) == 0:
-    raise NotImplementedError
-  elem_bytewidth = bytewidth(memref_ty.element_type)
   rank = len(memref_ty.shape)
   # TODO: Read out memory space from memref
   space = "" if memory_space is None else "<" + str(memory_space) + ">"
   ptr_ty = ir.Type.parse("!llvm.ptr" + space)
-  desc_ty = ir.Type.parse(
-      f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
-      f" array<{rank} x i64>)>"
-  )
+  if rank == 0:
+    desc_ty = ir.Type.parse(f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64)>")
+  else:
+    desc_ty = ir.Type.parse(
+        f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
+        f" array<{rank} x i64>)>"
+    )
   desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
   aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
+
+  elem_bytewidth = bytewidth(memref_ty.element_type)
   offset_elems = llvm.extractvalue(i64, desc, [2])
   offset_bytes = llvm.mul(
       offset_elems,
@@ -586,7 +586,7 @@ class DialectLoweringTest(MosaicGpuTest):
           llvm.UndefOp(workgroup_ptr_ty()),
           arrival_count=1,
       )
-    mgpu.lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module, None)
 
     self.assertEmpty(
         list(filter(is_mosaic_gpu_op, self.module.body.operations))
@@ -604,7 +604,7 @@ class DialectLoweringTest(MosaicGpuTest):
           arrival_count=1,
       )
       scf.yield_([])
-    mgpu.lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module, None)
 
     self.assertEmpty(
         list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
@@ -626,7 +626,7 @@ class DialectLoweringTest(MosaicGpuTest):
       memref.copy(barriers_ref, barriers_ref)
 
     self.assertTrue(self.module.operation.verify())
-    mgpu.lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module, None)
     self.assertTrue(self.module.operation.verify())
 
     all_mbarrier_init_shared_ops = find_if(
@@ -654,7 +654,7 @@ class DialectLoweringTest(MosaicGpuTest):
     with self.assertRaisesRegex(
         ValueError, "missing a layout and can not be lowered"
     ):
-      mgpu.lower_mgpu_dialect(self.module)
+      mgpu.lower_mgpu_dialect(self.module, None)
 
   def test_lowering_eliminates_layouts(self):
     shape = (4, 128)
@@ -670,7 +670,7 @@ class DialectLoweringTest(MosaicGpuTest):
         )
     ])
 
-    mgpu.lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module, None)
 
     all_ops_with_layouts = find_if(
         self.module,
@@ -691,7 +691,7 @@ class DialectLoweringTest(MosaicGpuTest):
       vector.store(array, ref, [zero_index, zero_index])
 
     mgpu.infer_layout(self.module)
-    mgpu.lower_mgpu_dialect(self.module)
+    mgpu.lower_mgpu_dialect(self.module, None)
 
     all_loads = find_if(
         self.module,
@@ -47,6 +47,8 @@ except ImportError:
   z = 2
 else:
   import jax.experimental.mosaic.gpu as mgpu
+  from jax.experimental.mosaic.gpu import core
+  from jax.experimental.mosaic.gpu import launch_context
   from jax.experimental.mosaic.gpu import utils as utils
   from jax.experimental.mosaic.gpu import profiler
   from jax.experimental.mosaic.gpu.utils import *  # noqa: F403
@@ -1937,6 +1939,103 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
 
     self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
 
+  def test_pointwise_kernel_with_tma(self):
+    def add(
+        ctx: launch_context.LaunchContext,
+        a_gmem_ref: ir.Value,
+        b_gmem_ref: ir.Value,
+        result_gmem_ref: ir.Value,
+        smem: list[ir.Value],
+    ):
+      del ctx
+      a_smem_ref, b_smem_ref, result_smem_ref = smem[:3]
+      tma_barrier = smem[3]
+      memref_type = ir.MemRefType(a_gmem_ref.type)
+      shape = memref_type.shape
+      elt_type = memref_type.element_type
+
+      zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
+
+      with utils.single_thread():
+        memref_bytes = utils.bytewidth(elt_type)  # Also correct if rank == 0
+        for size in shape:
+          memref_bytes *= size
+        nvvm.mbarrier_arrive_expect_tx_shared(
+            tma_barrier.get_ptr(),
+            arith.constant(ir.IntegerType.get_signless(32), 2*memref_bytes),
+        )
+
+      # GMEM -> SMEM
+      mgpu_dialect.async_load(
+          source=a_gmem_ref,
+          destination=a_smem_ref,
+          barrier=tma_barrier.as_dialect_barrier(),
+          indices=[zero_i32, zero_i32],
+          slice_lengths=shape,
+          transforms=ir.ArrayAttr.get([]),
+          collective=ir.ArrayAttr.get([]),
+          arrive=False,
+      )
+      mgpu_dialect.async_load(
+          source=b_gmem_ref,
+          destination=b_smem_ref,
+          barrier=tma_barrier.as_dialect_barrier(),
+          indices=[zero_i32, zero_i32],
+          slice_lengths=shape,
+          transforms=ir.ArrayAttr.get([]),
+          collective=ir.ArrayAttr.get([]),
+          arrive=False,
+      )
+
+      tma_barrier.wait()
+
+      zero_index = arith.constant(ir.IndexType.get(), 0)
+
+      # SMEM -> registers
+      ab_type = ir.VectorType.get(shape, elt_type)
+      a = vector.load(ab_type, a_smem_ref, [zero_index, zero_index])
+      b = vector.load(ab_type, b_smem_ref, [zero_index, zero_index])
+
+      # Computation
+      add = arith.addf(arith.addf(a, b), b)
+
+      # Registers -> SMEM
+      vector.store(add, result_smem_ref, [zero_index, zero_index])
+
+      # SMEM -> GMEM
+      mgpu_dialect.async_store(
+          source=result_smem_ref,
+          destination=result_gmem_ref,
+          indices=[zero_i32, zero_i32],
+          slice_lengths=shape,
+          transforms=ir.ArrayAttr.get([]),
+      )
+      nvvm.cp_async_bulk_wait_group(0)
+      utils.warpgroup_barrier()
+
+    dtype = jnp.bfloat16
+    shape = (128, 128)
+    jax_shape = jax.ShapeDtypeStruct(shape, dtype)
+    kernel = mgpu.as_gpu_kernel(
+        add,
+        grid=(1, 1, 1),
+        block=(128, 1, 1),
+        in_shape=(jax_shape, jax_shape),
+        out_shape=jax_shape,
+        smem_scratch_shape=[
+            jax_shape,
+            jax_shape,
+            jax_shape,
+            core.TMABarrier(1),
+        ],
+        thread_semantics=mgpu.ThreadSemantics.Warpgroup,
+    )
+
+    x = self.prng.uniform(-1, 1, shape).astype(dtype)
+    y = self.prng.uniform(-1, 1, shape).astype(dtype)
+
+    self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y)
+
 
 class UtilsTest(TestCase):
   @parameterized.parameters(