[Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is impractical: it allocates memory (which can deadlock due to NCCL), performs a synchronous host-to-device copy, and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors. What's more, we pack them all into a single buffer so that a single asynchronous copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer, letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
This commit is contained in:
parent 268b39d426
commit 9b0319512a
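To make the packing scheme concrete, here is a minimal, self-contained Python sketch of the bookkeeping described above. It is illustrative only: the ScratchPacker class and the write_fake_descriptor helper are made-up names, not part of this commit. The idea is that each descriptor reserves an aligned slot in a single host staging buffer while the kernel body is traced, the buffer is created only once the total size is known, and every slot is then filled right before one asynchronous host-to-device copy.

import numpy as np

TMA_DESCRIPTOR_BYTES = 128
TMA_DESCRIPTOR_ALIGNMENT = 64


class ScratchPacker:
  """Hands out aligned slots in one host staging buffer (illustrative only)."""

  def __init__(self):
    self.next_offset = 0
    self.host_init = []  # deferred callbacks that fill their slot later

  def alloc(self, size, alignment, init):
    if self.next_offset % alignment:
      raise NotImplementedError("padding to alignment is a TODO, as in the real code")
    offset = self.next_offset
    self.next_offset += size
    self.host_init.append(lambda buf, off=offset: init(buf, off))
    return offset


def write_fake_descriptor(buf, off):
  # Stand-in for mosaic_gpu_init_tma_desc, which writes a real CUtensorMap here.
  buf[off:off + TMA_DESCRIPTOR_BYTES] = 0xAB


packer = ScratchPacker()
for _ in range(3):  # three hypothetical TMA descriptors used by a kernel
  packer.alloc(TMA_DESCRIPTOR_BYTES, TMA_DESCRIPTOR_ALIGNMENT, write_fake_descriptor)

# Only now is the total scratch size known, so the buffer is created late.
host_buffer = np.zeros(packer.next_offset, np.uint8)
for callback in packer.host_init:
  callback(host_buffer)
# host_buffer now holds all descriptors back to back, so a single asynchronous
# host-to-device copy (mosaic_gpu_memcpy_async_h2d below) is sufficient.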
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 # Copyright 2024 The JAX Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,6 +26,7 @@ from typing import Any, Sequence

 import jax
 from jax._src import config
+from jax._src import core as jax_core
 from jax._src.interpreters import mlir
 from jax._src.lib import xla_client
 from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
@@ -58,6 +60,9 @@ else:
   PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
   NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")

+TMA_DESCRIPTOR_BYTES = 128
+TMA_DESCRIPTOR_ALIGNMENT = 64
+

 c = mgpu.c  # This is too common to fully qualify.

@@ -97,11 +102,13 @@ mosaic_gpu_p.multiple_results = True


 @mosaic_gpu_p.def_abstract_eval
-def _mosaic_gpu_abstract_eval(*_, module, out_types):
+def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
+  del module, gmem_scratch_bytes  # Unused.
   return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]


-def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
+def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
+  del out_types  # Unused.
   runtime_path = (
       pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent.parent.parent
       / "mosaic" / "gpu" / "libmlir_cuda_runtime.so"
@@ -127,11 +134,16 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
   )  # pytype: disable=attribute-error
   op = mlir.custom_call(
       "mosaic_gpu",
-      result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+      result_types=[
+          *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out),
+          mlir.aval_to_ir_type(
+              jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
+          ),
+      ],
       operands=args,
       backend_config=ptr_bytes,
   )
-  return op.results
+  return op.results[:-1]  # Skip the scratch space.

 mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")

@@ -227,7 +239,12 @@ OnDeviceProfiler = profiler.OnDeviceProfiler
 @dataclasses.dataclass()
 class LaunchContext:
   launch_op: gpu.LaunchOp
+  gmem_scratch_ptr: ir.Value
   profiler: OnDeviceProfiler | None = None
+  next_scratch_offset: int = 0
+  host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
+      default_factory=list, init=False
+  )
   tma_descriptors: dict[
       tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
       ir.Value,
@@ -241,6 +258,37 @@ class LaunchContext:
     else:
       yield

+  def _alloc_scratch(
+      self,
+      size: int,
+      alignment: int | None = None,
+      host_init: Callable[[ir.Value], None] = lambda _: None,
+      device_init: Callable[[ir.Value], Any] = lambda x: x,
+  ) -> ir.Value:
+    """Allocates a GMEM scratch buffer.
+
+    The buffer is initialized on the host and then copied to GMEM before the
+    kernel launch.
+    """
+    i8 = ir.IntegerType.get_signless(8)
+    ptr_ty = ir.Type.parse("!llvm.ptr")
+    if alignment is None:
+      alignment = size
+    if self.next_scratch_offset % alignment:
+      raise NotImplementedError  # TODO(apaszke): Pad to match alignment
+    alloc_base = self.next_scratch_offset
+    self.next_scratch_offset += size
+    def host_init_wrapped(host_ptr):
+      with ir.InsertionPoint(self.launch_op):
+        host_init(
+            llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
+        )
+    self.host_scratch_init.append(host_init_wrapped)
+    with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
+      return device_init(llvm.getelementptr(
+          ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8
+      ))
+
   def _get_tma_desc(
       self,
       ref,
@@ -265,13 +313,42 @@ class LaunchContext:
       with ir.InsertionPoint(self.launch_op):
         for t in gmem_transform:
           ref = t.apply(ref)
-        ref_unranked = memref.cast(
-            ir.UnrankedMemRefType.get(ref_ty.element_type, None), ref
-        )
-        tma_desc = nvgpu.tma_create_descriptor(
-            tensor_map_ty,
-            ref_unranked,
-            [c(s, index) for s in transformed_slice_shape],
+        ref_ty = ir.MemRefType(ref.type)
+
+      i64 = ir.IntegerType.get_signless(64)
+      ptr_ty = ir.Type.parse("!llvm.ptr")
+      def init_tma_desc(host_ptr):
+        _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
+        aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
+        as_i64 = lambda i: arith.index_cast(i64, i)
+        alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx))
+        llvm_dyn = -2147483648  # TODO(apaszke): Improve the MLIR bindings...
+        base_ptr = llvm.getelementptr(
+            ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type,
+        )
+        rank = ref_ty.rank
+        assert rank * 2 == len(sizes_and_strides)
+        args = [
+            host_ptr,
+            base_ptr,
+            c(utils.bytewidth(ref_ty.element_type), i64),
+            c(rank, i64),
+            utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
+            utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
+            c(0 if swizzle is None else swizzle, i64),
+            utils.pack_array([c(v, i64) for v in transformed_slice_shape]),
+        ]
+        func.call([], "mosaic_gpu_init_tma_desc", args)
+      def cast_tma_desc(device_ptr):
+        nvvm.prefetch_tensormap(device_ptr)
+        return builtin.unrealized_conversion_cast(
+            [tensor_map_ty], [device_ptr]
+        )
+      tma_desc = self._alloc_scratch(
+          TMA_DESCRIPTOR_BYTES,
+          alignment=TMA_DESCRIPTOR_ALIGNMENT,
+          host_init=init_tma_desc,
+          device_init=cast_tma_desc,
+      )
       self.tma_descriptors[tma_desc_key] = tma_desc
     return tma_desc
@@ -378,18 +455,14 @@ class LaunchContext:
     nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
     gpu.barrier()  # Groups are supposedly tracked per-thread

-  def _prefetch_tma_descs(self):
-    with ir.InsertionPoint(self.launch_op.body.blocks[0]):
-      with mgpu.once():
-        for desc in self.tma_descriptors.values():
-          nvgpu.tma_prefetch_descriptor(desc)
-

 # TODO(apaszke): Inline this
 @contextlib.contextmanager
 def _launch(
     token,
     grid,
     block,
+    gmem_scratch_ptr,
     smem_buffers,
     profiler_spec: profiler.ProfilerSpec | None = None,
     maybe_prof_buffer: ir.Value | None = None,
@@ -449,7 +522,7 @@ def _launch(
     else:
       prof = None
     smem_ref_tree = jax.tree.unflatten(smem_buffer_tree, smem_refs)
-    yield LaunchContext(launch_op, prof), smem_ref_tree
+    yield LaunchContext(launch_op, gmem_scratch_ptr, prof), smem_ref_tree
     if prof is not None:
       prof.finalize(grid=grid)
     gpu.terminator()
@@ -466,6 +539,8 @@ def as_gpu_kernel(
 ):
   ptr_ty = ir.Type.parse("!llvm.ptr")
   token_ty = ir.Type.parse("!gpu.async.token")
+  i8 = ir.IntegerType.get_signless(8)
+  i64 = ir.IntegerType.get_signless(64)

   def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
     return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
@@ -489,20 +564,46 @@ def as_gpu_kernel(

   module = ir.Module.create()
   with ir.InsertionPoint(module.body):
+    _declare_runtime_functions()
+    gmem_scratch_bytes = 0
     @func.FuncOp.from_py_func(ptr_ty, ptr_ty)
     def main(token_ptr, buffers):
+      nonlocal gmem_scratch_bytes
       token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
       arg_refs = []
+      i = -1
       for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
         ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
         arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
+      gmem_scratch_ptr = llvm.LoadOp(
+          ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i + 1], ptr_ty)
+      )
       in_refs = arg_refs[:len(in_ref_tys)]
       out_refs = arg_refs[len(in_ref_tys):]
       prof_buffer = out_refs.pop() if prof_spec is not None else None
       with _launch(
-          token, grid, block, smem_scratch_shape, prof_spec, prof_buffer
+          token, grid, block, gmem_scratch_ptr, smem_scratch_shape,
+          prof_spec, prof_buffer
       ) as (launch_ctx, smem_refs):
         body(launch_ctx, *in_refs, *out_refs, smem_refs)
+        gmem_scratch_bytes = launch_ctx.next_scratch_offset
+        # Allocate and initialize the host buffer right before the launch.
+        # Note that we couldn't do that before, because we had to run the body
+        # to learn what the scratch contains.
+        with ir.InsertionPoint(launch_ctx.launch_op):
+          host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8)
+          for init_callback in launch_ctx.host_scratch_init:
+            init_callback(host_scratch_ptr)
+          func.call(
+              [],
+              "mosaic_gpu_memcpy_async_h2d",
+              [
+                  gmem_scratch_ptr,
+                  host_scratch_ptr,
+                  c(gmem_scratch_bytes, i64),
+                  token_ptr,
+              ],
+          )
     main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
   module.operation.verify()

@@ -523,7 +624,12 @@ def as_gpu_kernel(
   pass_manager.run(module.operation)

   def bind(*args):
-    return mosaic_gpu_p.bind(*args, out_types=out_shape, module=module)
+    return mosaic_gpu_p.bind(
+        *args,
+        out_types=out_shape,
+        module=module,
+        gmem_scratch_bytes=gmem_scratch_bytes,
+    )

   if prof_spec is not None:
     @jax.jit
@@ -552,6 +658,21 @@ def as_gpu_kernel(
   return kernel


+def _declare_runtime_functions():
+  """Declares the runtime functions that can be used by the generated code."""
+  ptr_ty = ir.Type.parse("!llvm.ptr")
+  i64 = ir.IntegerType.get_signless(64)
+  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
+  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
+  func.FuncOp(
+      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
+  )
+  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
+  func.FuncOp(
+      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
+  )
+
+
 def dump_low_level(module):
   dump_ptx = mosaic_gpu_dump_ptx.value
   dump_ptxas = mosaic_gpu_dump_ptxas.value
@@ -69,6 +69,19 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
   return builtin.unrealized_conversion_cast([memref_ty], [desc])


+def pack_array(values):
+  if not values:
+    raise ValueError("Empty array")
+  elem_ty = values[0].type
+  i64 = ir.IntegerType.get_signless(64)
+  ptr_ty = ir.Type.parse("!llvm.ptr")
+  arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty)
+  for i, v in enumerate(values):
+    elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty)
+    llvm.store(v, elem_ptr)
+  return arr_ptr
+
+
 def get_contiguous_strides(xs):
   strides_ret = []
   stride = 1
@@ -37,6 +37,7 @@ cc_library(
     defines = ["JAX_GPU_CUDA=1"],
     visibility = ["//visibility:public"],
     deps = [
+        "@xla//xla/tsl/cuda:cupti",
         "@local_config_cuda//cuda:cuda_headers",
         "@local_config_cuda//cuda:cudnn_header",
     ],
@@ -219,6 +219,12 @@ pybind_extension(
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
+    linkopts = select({
+        "@xla//xla/python:use_jax_cuda_pip_rpaths": [
+            "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
+        ],
+        "//conditions:default": [],
+    }),
     visibility = ["//third_party/py/jax:__subpackages__"],
     deps = [
         ":jaxlib_mlir_capi_shared_library",
@@ -227,6 +233,7 @@ pybind_extension(
         "//jaxlib/mosaic/gpu:mlir_capi",
         "@nanobind",
         "@xla//xla/service:custom_call_status",
+        "@xla//xla/tsl/cuda:cudart",
     ],
 )

@@ -87,9 +87,20 @@ cc_library(
     alwayslink = True,
 )

+cc_library(
+    name = "runtime",
+    srcs = ["runtime.cc"],
+    deps = [
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
+
 cc_binary(
     name = "libmlir_cuda_runtime.so",
-    srcs = ["@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp"],
+    srcs = [
+        "runtime.cc",
+        "@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp",
+    ],
     copts = ["-fvisibility=default"],
     linkopts = select({
         "@xla//xla/python:use_jax_cuda_pip_rpaths": [
jaxlib/mosaic/gpu/runtime.cc (new file, 95 lines)
@@ -0,0 +1,95 @@
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include "third_party/gpus/cuda/include/cuda.h"

extern "C" {

void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
                              int64_t elem_bytewidth, int64_t rank,
                              int64_t *sizes, int64_t *strides,
                              int64_t swizzle_bytes, int64_t *window_shape) {
  CUtensorMapDataType data_type;
  if (elem_bytewidth == 1) {
    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
  } else if (elem_bytewidth == 2) {
    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
  } else if (elem_bytewidth == 4) {
    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
  } else if (elem_bytewidth == 8) {
    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64;
  } else {
    fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth);
    abort();
  }
  cuuint64_t tma_sizes[5] = {1, 1, 1, 1, 1};
  for (int i = 0; i < rank; ++i) {
    tma_sizes[i] = static_cast<cuuint64_t>(sizes[rank - i - 1]);
  }
  cuuint64_t tma_strides[5] = {1, 1, 1, 1, 1};
  if (strides[rank - 1] != 1) {
    fprintf(stderr, "Minormost stride must be 1, but got %ld\n",
            strides[rank - 1]);
    abort();
  }
  for (int i = 0; i < rank - 1; ++i) {  // We skip the implicit minor stride.
    tma_strides[i] =
        static_cast<cuuint64_t>(strides[rank - i - 2] * elem_bytewidth);
  }
  cuuint32_t tma_window_shape[5] = {1, 1, 1, 1, 1};
  for (int64_t i = 0; i < rank; ++i) {
    tma_window_shape[i] = static_cast<cuuint32_t>(window_shape[rank - i - 1]);
  }
  cuuint32_t element_strides[5] = {1, 1, 1, 1, 1};
  CUtensorMapSwizzle swizzle;
  if (swizzle_bytes == 0) {
    swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
  } else if (swizzle_bytes == 32) {
    swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
  } else if (swizzle_bytes == 64) {
    swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
  } else if (swizzle_bytes == 128) {
    swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
  } else {
    fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes);
    abort();
  }
  CUresult result = cuTensorMapEncodeTiled(
      tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides,
      tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
      CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  if (result != CUDA_SUCCESS) {
    const char *ptr = nullptr;
    cuGetErrorString(result, &ptr);
    fprintf(stderr, "cuTensorMapEncodeTiled failed: %s\n", ptr);
    abort();
  }
}

void mosaic_gpu_memcpy_async_h2d(CUdeviceptr dst, void *src, uint64_t bytes,
                                 CUstream stream) {
  CUresult result = cuMemcpyHtoDAsync(dst, src, bytes, stream);
  if (result != CUDA_SUCCESS) {
    const char *ptr = nullptr;
    cuGetErrorString(result, &ptr);
    fprintf(stderr, "cuMemcpyAsync failed: %s\n", ptr);
    abort();
  }
}

}
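To make the index gymnastics in mosaic_gpu_init_tma_desc above easier to follow, here is a small Python rendering of the same conversion. It is illustrative only: the to_tma_layout helper and the example shape are not part of the commit. Sizes and strides arrive major-to-minor and in elements, while cuTensorMapEncodeTiled expects minor-to-major arrays, strides in bytes, and the innermost element stride left implicit.

def to_tma_layout(sizes, strides, elem_bytewidth):
  # Mirror of the loops above: reverse to minor-to-major order, convert
  # strides to bytes, and skip the implicit innermost stride of one element.
  rank = len(sizes)
  assert strides[-1] == 1, "minormost stride must be 1 (in elements)"
  tma_sizes = [1] * 5
  tma_strides = [1] * 5
  for i in range(rank):
    tma_sizes[i] = sizes[rank - i - 1]
  for i in range(rank - 1):
    tma_strides[i] = strides[rank - i - 2] * elem_bytewidth
  return tma_sizes, tma_strides


# A hypothetical row-major (128, 64) float32 array: strides are (64, 1) elements.
print(to_tma_layout([128, 64], [64, 1], 4))
# -> ([64, 128, 1, 1, 1], [256, 1, 1, 1, 1])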