# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Callable
import contextlib
import ctypes
import dataclasses
import functools
import itertools
import os
import pathlib
import subprocess
import tempfile
import time
from typing import Any, Generic, Sequence, TypeVar

import jax
from jax._src import config
from jax._src import core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import xla_client
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import builtin
from jaxlib.mlir.dialects import func
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvgpu
from jaxlib.mlir.dialects import nvvm
from jaxlib.mlir.passmanager import PassManager
import numpy as np

from . import dsl as mgpu
from . import profiler
from . import utils

# mypy: ignore-errors

# MLIR can't find libdevice unless we point it to the CUDA path
# TODO(apaszke): Unify with jax._src.lib.cuda_path
CUDA_ROOT = "/usr/local/cuda"
if os.environ.get("CUDA_ROOT") is None:
  os.environ["CUDA_ROOT"] = CUDA_ROOT
else:
  CUDA_ROOT = os.environ["CUDA_ROOT"]

PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")

TMA_DESCRIPTOR_BYTES = 128
TMA_DESCRIPTOR_ALIGNMENT = 64


c = mgpu.c  # This is too common to fully qualify.


RUNTIME_PATH = None
try:
  from jax._src.lib import mosaic_gpu as mosaic_gpu_lib

  RUNTIME_PATH = (
      pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent
      / "libmosaic_gpu_runtime.so"
  )
except ImportError:
  pass

if RUNTIME_PATH and RUNTIME_PATH.exists():
  # Set this so that the custom call can find it
  os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)


mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
mosaic_gpu_p.multiple_results = True


@mosaic_gpu_p.def_abstract_eval
def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
  del module, gmem_scratch_bytes  # Unused.
  return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]

# TODO(apaszke): Implement a proper system for managing kernel lifetimes
kernel_idx = itertools.count()

def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
  del out_types  # Unused.
  idx_bytes = next(kernel_idx).to_bytes(8, byteorder="little")
  op = mlir.custom_call(
      "mosaic_gpu",
      result_types=[
          *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out),
          mlir.aval_to_ir_type(
              jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
          ),
      ],
      operands=args,
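      # Added note (not in the original source): the backend config is the
      # 8-byte little-endian kernel index followed by the serialized MLIR
      # module.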
      backend_config=idx_bytes
      + module.operation.get_asm(binary=True, enable_debug_info=True),
  )
  return op.results[:-1]  # Skip the scratch space.

mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")


@dataclasses.dataclass(frozen=True)
class MemRefTransform:
  def apply(self, ref: ir.Value) -> ir.Value:
    raise NotImplementedError("Subclasses should override this method")

  def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
    raise NotImplementedError("Subclasses should override this method")

  def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
    raise NotImplementedError("Subclasses should override this method")


@dataclasses.dataclass(frozen=True)
class TileTransform(MemRefTransform):
  """Tiles a suffix of memref dimensions.

  For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
  the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with
  the tile shape, and the size of tiled dimensions is divided by the tile size.
  This is especially useful for swizzled WGMMA, which expects tiled layouts in
  shared memory.
  """
  tiling: tuple[int, ...]

  def apply(self, ref: ir.Value) -> ir.Value:
    untiled_rank = ir.MemRefType(ref.type).rank
    tiling_rank = len(self.tiling)
    tiled_rank = untiled_rank + tiling_rank
    for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]):
      ref = mgpu.memref_unfold(ref, d, (None, t))
    permutation = (
        *range(untiled_rank - tiling_rank),
        *range(untiled_rank - tiling_rank, tiled_rank, 2),
        *range(untiled_rank - tiling_rank + 1, tiled_rank, 2),
    )
    return mgpu.memref_transpose(ref, permutation)

  def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
    index = ir.IndexType.get()
    tiling_rank = len(self.tiling)
    return (
        *idx[:-tiling_rank],
        *(
            arith.divui(i, c(t, index))
            for i, t in zip(idx[-tiling_rank:], self.tiling)
        ),
        *(
            arith.remui(i, c(t, index))
            for i, t in zip(idx[-tiling_rank:], self.tiling)
        ),
    )

  def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
    # Note that this also checks that tiled dims are not squeezed. Their slice
    # size would be 1 if so.
    tiling_rank = len(self.tiling)
    for size, tile_size in zip(shape[-tiling_rank:], self.tiling):
      if size % tile_size:
        raise ValueError(
            f"Expected GMEM slice shape {shape} suffix to be a multiple"
            f" of tiling {self.tiling}"
        )
    return (
        *shape[:-tiling_rank],
        *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)),
        *self.tiling,
    )
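

# Added example (not in the original source): with TileTransform((64, 32)),
# `apply` turns a (5, 128, 128) memref into a (5, 2, 4, 64, 32) one,
# `transform_index` maps an index (i, j, k) to
# (i, j // 64, k // 32, j % 64, k % 32) (as arith ops on IR values), and
# `transform_shape((5, 128, 128))` returns (5, 2, 4, 64, 32).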


@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemRefTransform):
  """Transposes memref dimensions."""
  permutation: tuple[int, ...]

  def __post_init__(self):
    if len(self.permutation) != len(set(self.permutation)):
      raise ValueError("permutation must not contain repeated dimensions")

  def apply(self, ref: ir.Value) -> ir.Value:
    return mgpu.memref_transpose(ref, self.permutation)

  def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
    return tuple(idx[p] for p in self.permutation)

  def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
    return tuple(shape[p] for p in self.permutation)
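

# Added example (not in the original source): TransposeTransform((1, 0)) swaps
# the two dimensions of a rank-2 memref: `transform_shape((4, 8))` returns
# (8, 4) and `transform_index((i, j))` returns (j, i).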


OnDeviceProfiler = profiler.OnDeviceProfiler


@dataclasses.dataclass()
class LaunchContext:
  launch_op: gpu.LaunchOp
  profiler: OnDeviceProfiler | None = None
  next_scratch_offset: int = 0
  host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
      default_factory=list, init=False
  )
  tma_descriptors: dict[
      tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
      ir.Value,
  ] = dataclasses.field(default_factory=dict, init=False)

  @contextlib.contextmanager
  def named_region(self, *args, **kwargs):
    if self.profiler is not None:
      with self.profiler.record(*args, **kwargs):
        yield
    else:
      yield

  def _alloc_scratch(
      self,
      size: int,
      alignment: int | None = None,
      host_init: Callable[[ir.Value], None] = lambda _: None,
      device_init: Callable[[ir.Value], Any] = lambda x: x,
  ) -> ir.Value:
    """Allocates a GMEM scratch buffer.

    The buffer is initialized on the host and then copied to GMEM before the
    kernel launch.
    """
    i8 = ir.IntegerType.get_signless(8)
    ptr_ty = ir.Type.parse("!llvm.ptr")
    if alignment is None:
      alignment = size
    if self.next_scratch_offset % alignment:
      raise NotImplementedError  # TODO(apaszke): Pad to match alignment
    alloc_base = self.next_scratch_offset
    self.next_scratch_offset += size
    def host_init_wrapped(host_ptr):
      host_init(
          llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
      )
    self.host_scratch_init.append(host_init_wrapped)

    with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
      ptr_ty = ir.Type.parse("!llvm.ptr")
      const_ptr_ty = ir.Type.parse("!llvm.ptr<4>")
      gmem_scratch_ptr = llvm.call_intrinsic(
          ptr_ty,
          "llvm.nvvm.ptr.constant.to.gen.p0.p4",
          [llvm.mlir_addressof(const_ptr_ty, "global_scratch")],
      )
      return device_init(llvm.getelementptr(
          ptr_ty, gmem_scratch_ptr, [], [alloc_base], i8
      ))
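
  # Added note (not in the original source): the `host_init` callbacks
  # registered above are replayed in _lower_as_gpu_kernel once the kernel body
  # has been traced and the total scratch size is known; the IR they emit runs
  # on the host right before the kernel launch.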

  def _get_tma_desc(
      self,
      ref,
      gmem_transform: tuple[MemRefTransform, ...],
      transformed_slice_shape: tuple[int, ...],
      swizzle: int | None,
  ):
    index = ir.IndexType.get()
    ref_ty = ir.MemRefType(ref.type)
    tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform)
    if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
      swizzle_str = f"swizzle_{swizzle}b" if swizzle is not None else "none"
      default_tensor_map_attrs = dict(
          swizzle=swizzle_str, l2promo="none", oob="zero", interleave="none"
      )
      tensor_map_ty = utils.get_tensormap_descriptor(
          tensor=(
              f"memref<{'x'.join(map(str, transformed_slice_shape))}x{ref_ty.element_type}, 3>"
          ),
          **default_tensor_map_attrs,
      )
      with ir.InsertionPoint(self.launch_op):
        for t in gmem_transform:
          ref = t.apply(ref)
        ref_ty = ir.MemRefType(ref.type)

        i64 = ir.IntegerType.get_signless(64)
        ptr_ty = ir.Type.parse("!llvm.ptr")
        def init_tma_desc(host_ptr):
          _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
          aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
          as_i64 = lambda i: arith.index_cast(i64, i)
          alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx))
          llvm_dyn = -2147483648  # TODO(apaszke): Improve the MLIR bindings...
          base_ptr = llvm.getelementptr(
              ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type,
          )
          rank = ref_ty.rank
          assert rank * 2 == len(sizes_and_strides)
          args = [
              host_ptr,
              base_ptr,
              c(utils.bytewidth(ref_ty.element_type), i64),
              c(rank, i64),
              utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
              utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
              c(0 if swizzle is None else swizzle, i64),
              utils.pack_array([c(v, i64) for v in transformed_slice_shape]),
          ]
          func.call([], "mosaic_gpu_init_tma_desc", args)
        def cast_tma_desc(device_ptr):
          # TODO(apaszke): Investigate why prefetching can cause launch failures
          # nvvm.prefetch_tensormap(device_ptr)
          return builtin.unrealized_conversion_cast(
              [tensor_map_ty], [device_ptr]
          )
        tma_desc = self._alloc_scratch(
            TMA_DESCRIPTOR_BYTES,
            alignment=TMA_DESCRIPTOR_ALIGNMENT,
            host_init=init_tma_desc,
            device_init=cast_tma_desc,
        )
        self.tma_descriptors[tma_desc_key] = tma_desc
    return tma_desc

  def async_copy(
      self,
      *,
      src_ref,
      dst_ref,
      gmem_slice: Any = (),
      gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (),
      barrier: mgpu.Barrier | None = None,
      swizzle: int | None = None,
      arrive: bool | None = None,
      uniform: bool = True,
  ):
    index = ir.IndexType.get()
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
    src_ref_ty = ir.MemRefType(src_ref.type)
    dst_ref_ty = ir.MemRefType(dst_ref.type)
    element_type = src_ref_ty.element_type
    if element_type != dst_ref_ty.element_type:
      raise ValueError(
          f"Expected same element type, got {element_type} and"
          f" {dst_ref_ty.element_type}"
      )
    if not isinstance(gmem_transform, tuple):
      gmem_transform = (gmem_transform,)

    if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem:
      gmem_ref, smem_ref = src_ref, dst_ref
      if barrier is None:
        raise ValueError("Barriers are required for GMEM -> SMEM copies")
      if arrive is None:
        arrive = True  # Arrive by default
    elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None:
      gmem_ref, smem_ref = dst_ref, src_ref
      if barrier is not None:
        raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
      if arrive is not None:
        raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
    else:
      raise ValueError("Only SMEM <-> GMEM copies supported")
    # TODO(apaszke): This is a very approximate check. Improve it!
    expected_name = "builtin.unrealized_conversion_cast"
    if (
        gmem_ref.owner is None
        or gmem_ref.owner.opview.OPERATION_NAME != expected_name
    ):
      raise ValueError("GMEM reference in async_copy must be a kernel argument")

    base_indices, slice_shape, is_squeezed = utils.parse_indices(
        gmem_slice, ir.MemRefType(gmem_ref.type).shape
    )
    dyn_base_indices = tuple(
        c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
    )
    slice_shape = tuple(slice_shape)
    for t in gmem_transform:
      dyn_base_indices = t.transform_index(dyn_base_indices)
      slice_shape = t.transform_shape(slice_shape)
    for dim, squeezed in enumerate(is_squeezed):
      if squeezed:
        smem_ref = mgpu.memref_unsqueeze(smem_ref, dim)
    smem_ref_ty = ir.MemRefType(smem_ref.type)

    if slice_shape != tuple(smem_ref_ty.shape):
      raise ValueError(
          "Expected the SMEM reference to have the same shape as the tiled"
          f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
      )
    tma_desc = self._get_tma_desc(
        gmem_ref, gmem_transform, slice_shape, swizzle,
    )

    # nvgpu TMA instructions expect reversed indices...
    rev_dyn_based_indices = reversed(dyn_base_indices)

    uniform_ctx = (
        functools.partial(mgpu.single_thread, per_block=False)
        if uniform
        else contextlib.nullcontext
    )

    if gmem_ref is src_ref:
      with uniform_ctx():
        assert barrier is not None  # for pytype
        barrier_group = barrier.barrier_array.value
        barrier_idx = barrier.offset
        if arrive:
          slice_bytes = c(
              np.prod(slice_shape) * mgpu.bytewidth(element_type), index
          )
          nvgpu.mbarrier_arrive_expect_tx(
              barrier_group, slice_bytes, barrier_idx
          )
        nvgpu.tma_async_load(
            smem_ref, barrier_group, tma_desc, rev_dyn_based_indices, barrier_idx
        )
    else:
      with uniform_ctx():
        nvgpu.tma_async_store(smem_ref, tma_desc, rev_dyn_based_indices)
        nvvm.cp_async_bulk_commit_group()

  def await_async_copy(
      self, allow_groups: int, await_read_only: bool = False
  ):
    nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
    gpu.barrier()  # Groups are supposedly tracked per-thread
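

# Added usage sketch (not in the original source): inside a kernel body, a
# GMEM -> SMEM load and the SMEM -> GMEM write-back typically look like
#
#   ctx.async_copy(src_ref=in_gmem, dst_ref=in_smem, barrier=barrier)
#   # ... wait on the barrier, compute, write the result into out_smem ...
#   ctx.async_copy(src_ref=out_smem, dst_ref=out_gmem)
#   ctx.await_async_copy(0)
#
# where `ctx` is the LaunchContext and the ref names are hypothetical.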


# ShapeTrees currently can not contain unions.
ShapeTree = Any
RefTree = Any
T = TypeVar('T')


@dataclasses.dataclass(frozen=True)
class Union(Generic[T]):
  members: Sequence[T]
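

# Added note (not in the original source): members of a Union alias the same
# dynamic SMEM, so _launch below sizes the allocation as the maximum over the
# members rather than their sum. For example, a hypothetical
# Union([jax.ShapeDtypeStruct((128, 64), np.float32),
#        jax.ShapeDtypeStruct((256, 32), np.float16)])
# needs max(128 * 64 * 4, 256 * 32 * 2) = 32768 bytes of compute SMEM.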


def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
  return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize


def _construct_smem_reftree(
    dynamic_smem: ir.Value, smem_buffers: ShapeTree) -> RefTree:
  index = ir.IndexType.get()
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  flat_ref_tys, smem_buffer_tree = jax.tree.flatten(smem_buffers)
  smem_refs = []
  dynamic_smem_offset = 0
  for ref_ty in flat_ref_tys:
    mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype)
    tile_smem = memref.view(
        ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
        dynamic_smem, c(dynamic_smem_offset, index), [],
    )
    dynamic_smem_offset += _count_buffer_bytes(ref_ty)
    smem_refs.append(tile_smem)
  return jax.tree.unflatten(smem_buffer_tree, smem_refs)
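

# Added note (not in the original source): buffers are packed back to back in
# dynamic SMEM in flattening order, with no extra alignment padding. For
# example, a (128, 64) float32 buffer followed by a (32,) int8 buffer is
# viewed at byte offsets 0 and 128 * 64 * 4 = 32768 respectively.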


# TODO(apaszke): Inline this
@contextlib.contextmanager
def _launch(
    token,
    grid,
    block,
    smem_buffers: ShapeTree | Union[ShapeTree],
    profiler_spec: profiler.ProfilerSpec | None = None,
    maybe_prof_buffer: ir.Value | None = None,
):
  if (profiler_spec is None) != (maybe_prof_buffer is None):
    raise ValueError
  index = ir.IndexType.get()
  i32 = ir.IntegerType.get_signless(32)
  i8 = ir.IntegerType.get_signless(8)
  grid_vals = [c(i, index) for i in grid]
  block_vals = [c(i, index) for i in block]

  if isinstance(smem_buffers, Union):
    smem_disjoint_live_buffers_collections = smem_buffers.members
    compute_smem_bytes = max(
        sum(_count_buffer_bytes(l) for l in jax.tree.leaves(s))
        for s in smem_buffers.members)
  else:
    smem_disjoint_live_buffers_collections = [smem_buffers]
    compute_smem_bytes = sum(
        _count_buffer_bytes(l) for l in jax.tree.leaves(smem_buffers))

  smem_bytes = compute_smem_bytes
  if profiler_spec is not None:
    smem_bytes += profiler_spec.smem_bytes(block=block)

  # TODO(cperivol): Query the shared memory size programmatically.
  if smem_bytes > 228 * 1024:
    raise ValueError(
        "Mosaic GPU kernel exceeds available shared memory:"
        f" {smem_bytes=} > {228 * 1024}"
    )
  launch_op = gpu.LaunchOp(
      token.type, [token], *grid_vals, *block_vals,
      dynamicSharedMemorySize=c(smem_bytes, i32))
  launch_op.body.blocks.append(*([index] * 12))  # Append an empty block
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  with ir.InsertionPoint(launch_op.body.blocks[0]):
    dynamic_smem = gpu.dynamic_shared_memory(
        ir.MemRefType.get(
            (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
        )
    )
    smem_ref_trees = []

    for smem_live_buffers_collection in smem_disjoint_live_buffers_collections:
      smem_ref_tree = _construct_smem_reftree(
          dynamic_smem, smem_live_buffers_collection)
      smem_ref_trees.append(smem_ref_tree)

    if profiler_spec:
      prof_smem = memref.view(
          ir.MemRefType.get(
              (profiler_spec.smem_i32_elements(block=block),),
              i32, memory_space=smem,
          ),
          dynamic_smem, c(compute_smem_bytes, index), [],
      )
      prof = profiler.OnDeviceProfiler(
          profiler_spec, prof_smem, maybe_prof_buffer
      )
    else:
      prof = None

    if isinstance(smem_buffers, Union):
      smem_ref_tree: Union[RefTree] = Union(smem_ref_trees)
    else:
      smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else []

    yield LaunchContext(launch_op, prof), smem_ref_tree
    if prof is not None:
      prof.finalize(grid=grid, block=block)
    gpu.terminator()


def _lower_as_gpu_kernel(
    body,
    grid: tuple[int, ...],
    block: tuple[int, ...],
    in_shapes: tuple[Any, ...],
    out_shape,
    smem_scratch_shape: ShapeTree | Union[ShapeTree],
    prof_spec: profiler.ProfilerSpec | None = None,
):
  ptr_ty = ir.Type.parse("!llvm.ptr")
  token_ty = ir.Type.parse("!gpu.async.token")
  i8 = ir.IntegerType.get_signless(8)
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)

  def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
    return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))

  in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes]

  unwrap_output_tuple = False
  if isinstance(out_shape, list):
    out_shape = tuple(out_shape)
  elif not isinstance(out_shape, tuple):
    out_shape = (out_shape,)
    unwrap_output_tuple = True
  out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape]
  if prof_spec is not None:
    out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block))
    out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block))

  module = ir.Module.create()
  with ir.InsertionPoint(module.body):
    _declare_runtime_functions()
    gmem_scratch_bytes = 0
    global_scratch = llvm.GlobalOp(
        ir.Type.parse("!llvm.array<0 x i8>"),  # We don't know the shape yet.
        "global_scratch",
        ir.Attribute.parse("#llvm.linkage<external>"),
        addr_space=ir.IntegerAttr.get(i32, 4),  # GPU constant memory.
    )
    @func.FuncOp.from_py_func(ptr_ty, ptr_ty, ptr_ty)
    def main(token_ptr, buffers, gmem_scratch_ptr):
      nonlocal gmem_scratch_bytes
      token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
      arg_refs = []
      for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
        ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
        arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
      in_refs = arg_refs[:len(in_ref_tys)]
      out_refs = arg_refs[len(in_ref_tys):]
      prof_buffer = out_refs.pop() if prof_spec is not None else None
      with _launch(
          token, grid, block, smem_scratch_shape,
          prof_spec, prof_buffer
      ) as (launch_ctx, smem_refs):
        body(launch_ctx, *in_refs, *out_refs, smem_refs)
        gmem_scratch_bytes = launch_ctx.next_scratch_offset
        # Allocate and initialize the host buffer right before the launch.
        # Note that we couldn't do that before, because we had to run the body
        # to learn what the scratch contains.
        with ir.InsertionPoint(launch_ctx.launch_op):
          host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8)
          for init_callback in launch_ctx.host_scratch_init:
            init_callback(host_scratch_ptr)
          global_scratch.global_type = ir.TypeAttr.get(
              ir.Type.parse("!llvm.array<" + str(gmem_scratch_bytes) + " x i8>")
          )
          func.call(
              [],
              "mosaic_gpu_memcpy_async_h2d",
              [
                  gmem_scratch_ptr,
                  host_scratch_ptr,
                  c(gmem_scratch_bytes, i64),
                  token_ptr,
              ],
          )
    main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
  sym_tab = ir.SymbolTable(module.operation)
  sym_tab.insert(main.func_op)
  sym_tab.insert(global_scratch)
  module.operation.verify()

  return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple


def as_gpu_kernel(
    body,
    grid: tuple[int, ...],
    block: tuple[int, ...],
    in_shape,
    out_shape,
    smem_scratch_shape: ShapeTree | Union[ShapeTree],
    prof_spec: profiler.ProfilerSpec | None = None,
):
  if isinstance(in_shape, list):
    in_shape = tuple(in_shape)
  elif not isinstance(in_shape, tuple):
    in_shape = (in_shape,)

  module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
      _lower_as_gpu_kernel(
          body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec
      )
  )

  expected_arg_treedef = jax.tree.structure(in_shape)
  def _check_args(*args):
    arg_treedef = jax.tree.structure(args)
    if arg_treedef != expected_arg_treedef:
      raise ValueError(
          f"Invalid argument structure: expected {expected_arg_treedef}, got"
          f" {arg_treedef}, ({args=})"
      )

  def bind(*args):
    return mosaic_gpu_p.bind(
        *args,
        out_types=out_shape,
        module=module,
        gmem_scratch_bytes=gmem_scratch_bytes,
    )

  if prof_spec is not None:
    @jax.jit
    def prof_kernel(*args):
      _check_args(*args)
      *results, prof_buffer = bind(*args)
      def dump_profile(prof_buffer):
        out_file = os.path.join(
            os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"),
            f"{time.time_ns()}-trace.json",
        )
        try:
          with open(out_file, "x") as f:
            prof_spec.dump(prof_buffer, f, grid=grid, block=block)
        except FileExistsError:
          pass  # TODO: Retry
      jax.debug.callback(dump_profile, prof_buffer)
      return results[0] if unwrap_output_tuple else results
    return prof_kernel
  else:
    @jax.jit
    def kernel(*args):
      _check_args(*args)
      results = bind(*args)
      return results[0] if unwrap_output_tuple else results
    return kernel
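

# Added usage sketch (not in the original source):
#
#   def copy_body(ctx, src, dst, smem):
#     # ... emit MLIR that moves `src` to `dst`, e.g. via ctx.async_copy ...
#
#   kernel = as_gpu_kernel(
#       copy_body,
#       grid=(1, 1, 1),
#       block=(128, 1, 1),
#       in_shape=jax.ShapeDtypeStruct((128, 128), np.float32),
#       out_shape=jax.ShapeDtypeStruct((128, 128), np.float32),
#       smem_scratch_shape=jax.ShapeDtypeStruct((128, 128), np.float32),
#   )
#   y = kernel(x)
#
# The body receives (launch_ctx, *input_refs, *output_refs, smem_refs); names
# like `copy_body` and `x` are hypothetical.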


def _declare_runtime_functions():
  """Declares the runtime functions that can be used by the generated code."""
  ptr_ty = ir.Type.parse("!llvm.ptr")
  i64 = ir.IntegerType.get_signless(64)
  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
  func.FuncOp(
      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
  )
  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
  func.FuncOp(
      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
  )