[Mosaic GPU] Use a custom TMA descriptor initialization method

The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
This commit is contained in:
Adam Paszke 2024-04-26 09:39:54 -07:00 committed by jax authors
parent 268b39d426
commit 9b0319512a
6 changed files with 269 additions and 21 deletions

View File

@ -1,3 +1,4 @@
from collections.abc import Callable
# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -25,6 +26,7 @@ from typing import Any, Sequence
import jax
from jax._src import config
from jax._src import core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import xla_client
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
@ -58,6 +60,9 @@ else:
PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")
TMA_DESCRIPTOR_BYTES = 128
TMA_DESCRIPTOR_ALIGNMENT = 64
c = mgpu.c # This is too common to fully qualify.
@ -97,11 +102,13 @@ mosaic_gpu_p.multiple_results = True
@mosaic_gpu_p.def_abstract_eval
def _mosaic_gpu_abstract_eval(*_, module, out_types):
def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
del module, gmem_scratch_bytes # Unused.
return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]
def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
del out_types # Unused.
runtime_path = (
pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent.parent.parent
/ "mosaic" / "gpu" / "libmlir_cuda_runtime.so"
@ -127,11 +134,16 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
) # pytype: disable=attribute-error
op = mlir.custom_call(
"mosaic_gpu",
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
result_types=[
*(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out),
mlir.aval_to_ir_type(
jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
),
],
operands=args,
backend_config=ptr_bytes,
)
return op.results
return op.results[:-1] # Skip the scratch space.
mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
@ -227,7 +239,12 @@ OnDeviceProfiler = profiler.OnDeviceProfiler
@dataclasses.dataclass()
class LaunchContext:
launch_op: gpu.LaunchOp
gmem_scratch_ptr: ir.Value
profiler: OnDeviceProfiler | None = None
next_scratch_offset: int = 0
host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
default_factory=list, init=False
)
tma_descriptors: dict[
tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
ir.Value,
@ -241,6 +258,37 @@ class LaunchContext:
else:
yield
def _alloc_scratch(
self,
size: int,
alignment: int | None = None,
host_init: Callable[[ir.Value], None] = lambda _: None,
device_init: Callable[[ir.Value], Any] = lambda x: x,
) -> ir.Value:
"""Allocates a GMEM scratch buffer.
The buffer is initialized on the host and then copied to GMEM before the
kernel launch.
"""
i8 = ir.IntegerType.get_signless(8)
ptr_ty = ir.Type.parse("!llvm.ptr")
if alignment is None:
alignment = size
if self.next_scratch_offset % alignment:
raise NotImplementedError # TODO(apaszke): Pad to match alignment
alloc_base = self.next_scratch_offset
self.next_scratch_offset += size
def host_init_wrapped(host_ptr):
with ir.InsertionPoint(self.launch_op):
host_init(
llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
)
self.host_scratch_init.append(host_init_wrapped)
with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
return device_init(llvm.getelementptr(
ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8
))
def _get_tma_desc(
self,
ref,
@ -265,13 +313,42 @@ class LaunchContext:
with ir.InsertionPoint(self.launch_op):
for t in gmem_transform:
ref = t.apply(ref)
ref_unranked = memref.cast(
ir.UnrankedMemRefType.get(ref_ty.element_type, None), ref
)
tma_desc = nvgpu.tma_create_descriptor(
tensor_map_ty,
ref_unranked,
[c(s, index) for s in transformed_slice_shape],
ref_ty = ir.MemRefType(ref.type)
i64 = ir.IntegerType.get_signless(64)
ptr_ty = ir.Type.parse("!llvm.ptr")
def init_tma_desc(host_ptr):
_, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
as_i64 = lambda i: arith.index_cast(i64, i)
alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx))
llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings...
base_ptr = llvm.getelementptr(
ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type,
)
rank = ref_ty.rank
assert rank * 2 == len(sizes_and_strides)
args = [
host_ptr,
base_ptr,
c(utils.bytewidth(ref_ty.element_type), i64),
c(rank, i64),
utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
c(0 if swizzle is None else swizzle, i64),
utils.pack_array([c(v, i64) for v in transformed_slice_shape]),
]
func.call([], "mosaic_gpu_init_tma_desc", args)
def cast_tma_desc(device_ptr):
nvvm.prefetch_tensormap(device_ptr)
return builtin.unrealized_conversion_cast(
[tensor_map_ty], [device_ptr]
)
tma_desc = self._alloc_scratch(
TMA_DESCRIPTOR_BYTES,
alignment=TMA_DESCRIPTOR_ALIGNMENT,
host_init=init_tma_desc,
device_init=cast_tma_desc,
)
self.tma_descriptors[tma_desc_key] = tma_desc
return tma_desc
@ -378,18 +455,14 @@ class LaunchContext:
nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
gpu.barrier() # Groups are supposedly tracked per-thread
def _prefetch_tma_descs(self):
with ir.InsertionPoint(self.launch_op.body.blocks[0]):
with mgpu.once():
for desc in self.tma_descriptors.values():
nvgpu.tma_prefetch_descriptor(desc)
# TODO(apaszke): Inline this
@contextlib.contextmanager
def _launch(
token,
grid,
block,
gmem_scratch_ptr,
smem_buffers,
profiler_spec: profiler.ProfilerSpec | None = None,
maybe_prof_buffer: ir.Value | None = None,
@ -449,7 +522,7 @@ def _launch(
else:
prof = None
smem_ref_tree = jax.tree.unflatten(smem_buffer_tree, smem_refs)
yield LaunchContext(launch_op, prof), smem_ref_tree
yield LaunchContext(launch_op, gmem_scratch_ptr, prof), smem_ref_tree
if prof is not None:
prof.finalize(grid=grid)
gpu.terminator()
@ -466,6 +539,8 @@ def as_gpu_kernel(
):
ptr_ty = ir.Type.parse("!llvm.ptr")
token_ty = ir.Type.parse("!gpu.async.token")
i8 = ir.IntegerType.get_signless(8)
i64 = ir.IntegerType.get_signless(64)
def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
@ -489,20 +564,46 @@ def as_gpu_kernel(
module = ir.Module.create()
with ir.InsertionPoint(module.body):
_declare_runtime_functions()
gmem_scratch_bytes = 0
@func.FuncOp.from_py_func(ptr_ty, ptr_ty)
def main(token_ptr, buffers):
nonlocal gmem_scratch_bytes
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
arg_refs = []
i = -1
for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
gmem_scratch_ptr = llvm.LoadOp(
ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i + 1], ptr_ty)
)
in_refs = arg_refs[:len(in_ref_tys)]
out_refs = arg_refs[len(in_ref_tys):]
prof_buffer = out_refs.pop() if prof_spec is not None else None
with _launch(
token, grid, block, smem_scratch_shape, prof_spec, prof_buffer
token, grid, block, gmem_scratch_ptr, smem_scratch_shape,
prof_spec, prof_buffer
) as (launch_ctx, smem_refs):
body(launch_ctx, *in_refs, *out_refs, smem_refs)
gmem_scratch_bytes = launch_ctx.next_scratch_offset
# Allocate and initialize the host buffer right before the launch.
# Note that we couldn't do that before, because we had to run the body
# to learn what the scratch contains.
with ir.InsertionPoint(launch_ctx.launch_op):
host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8)
for init_callback in launch_ctx.host_scratch_init:
init_callback(host_scratch_ptr)
func.call(
[],
"mosaic_gpu_memcpy_async_h2d",
[
gmem_scratch_ptr,
host_scratch_ptr,
c(gmem_scratch_bytes, i64),
token_ptr,
],
)
main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
@ -523,7 +624,12 @@ def as_gpu_kernel(
pass_manager.run(module.operation)
def bind(*args):
return mosaic_gpu_p.bind(*args, out_types=out_shape, module=module)
return mosaic_gpu_p.bind(
*args,
out_types=out_shape,
module=module,
gmem_scratch_bytes=gmem_scratch_bytes,
)
if prof_spec is not None:
@jax.jit
@ -552,6 +658,21 @@ def as_gpu_kernel(
return kernel
def _declare_runtime_functions():
"""Declares the runtime functions that can be used by the generated code."""
ptr_ty = ir.Type.parse("!llvm.ptr")
i64 = ir.IntegerType.get_signless(64)
arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
func.FuncOp(
"mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
)
memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
func.FuncOp(
"mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
)
def dump_low_level(module):
dump_ptx = mosaic_gpu_dump_ptx.value
dump_ptxas = mosaic_gpu_dump_ptxas.value

View File

@ -69,6 +69,19 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
return builtin.unrealized_conversion_cast([memref_ty], [desc])
def pack_array(values):
if not values:
raise ValueError("Empty array")
elem_ty = values[0].type
i64 = ir.IntegerType.get_signless(64)
ptr_ty = ir.Type.parse("!llvm.ptr")
arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty)
for i, v in enumerate(values):
elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty)
llvm.store(v, elem_ptr)
return arr_ptr
def get_contiguous_strides(xs):
strides_ret = []
stride = 1

View File

@ -37,6 +37,7 @@ cc_library(
defines = ["JAX_GPU_CUDA=1"],
visibility = ["//visibility:public"],
deps = [
"@xla//xla/tsl/cuda:cupti",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
],

View File

@ -219,6 +219,12 @@ pybind_extension(
"-fexceptions",
"-fno-strict-aliasing",
],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
"-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
],
"//conditions:default": [],
}),
visibility = ["//third_party/py/jax:__subpackages__"],
deps = [
":jaxlib_mlir_capi_shared_library",
@ -227,6 +233,7 @@ pybind_extension(
"//jaxlib/mosaic/gpu:mlir_capi",
"@nanobind",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
],
)

View File

@ -87,9 +87,20 @@ cc_library(
alwayslink = True,
)
cc_library(
name = "runtime",
srcs = ["runtime.cc"],
deps = [
"@local_config_cuda//cuda:cuda_headers",
],
)
cc_binary(
name = "libmlir_cuda_runtime.so",
srcs = ["@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp"],
srcs = [
"runtime.cc",
"@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp",
],
copts = ["-fvisibility=default"],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [

View File

@ -0,0 +1,95 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include "third_party/gpus/cuda/include/cuda.h"
extern "C" {
void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
int64_t elem_bytewidth, int64_t rank,
int64_t *sizes, int64_t *strides,
int64_t swizzle_bytes, int64_t *window_shape) {
CUtensorMapDataType data_type;
if (elem_bytewidth == 1) {
data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if (elem_bytewidth == 2) {
data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
} else if (elem_bytewidth == 4) {
data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
} else if (elem_bytewidth == 8) {
data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64;
} else {
fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth);
abort();
}
cuuint64_t tma_sizes[5] = {1, 1, 1, 1, 1};
for (int i = 0; i < rank; ++i) {
tma_sizes[i] = static_cast<cuuint64_t>(sizes[rank - i - 1]);
}
cuuint64_t tma_strides[5] = {1, 1, 1, 1, 1};
if (strides[rank - 1] != 1) {
fprintf(stderr, "Minormost stride must be 1, but got %ld\n",
strides[rank - 1]);
abort();
}
for (int i = 0; i < rank - 1; ++i) { // We skip the implicit minor stride.
tma_strides[i] =
static_cast<cuuint64_t>(strides[rank - i - 2] * elem_bytewidth);
}
cuuint32_t tma_window_shape[5] = {1, 1, 1, 1, 1};
for (int64_t i = 0; i < rank; ++i) {
tma_window_shape[i] = static_cast<cuuint32_t>(window_shape[rank - i - 1]);
}
cuuint32_t element_strides[5] = {1, 1, 1, 1, 1};
CUtensorMapSwizzle swizzle;
if (swizzle_bytes == 0) {
swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
} else if (swizzle_bytes == 32) {
swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
} else if (swizzle_bytes == 64) {
swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
} else if (swizzle_bytes == 128) {
swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
} else {
fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes);
abort();
}
CUresult result = cuTensorMapEncodeTiled(
tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides,
tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
if (result != CUDA_SUCCESS) {
const char *ptr = nullptr;
cuGetErrorString(result, &ptr);
fprintf(stderr, "cuTensorMapEncodeTiled failed: %s\n", ptr);
abort();
}
}
void mosaic_gpu_memcpy_async_h2d(CUdeviceptr dst, void *src, uint64_t bytes,
CUstream stream) {
CUresult result = cuMemcpyHtoDAsync(dst, src, bytes, stream);
if (result != CUDA_SUCCESS) {
const char *ptr = nullptr;
cuGetErrorString(result, &ptr);
fprintf(stderr, "cuMemcpyAsync failed: %s\n", ptr);
abort();
}
}
}