From 9b0319512a3fafbf2d857e8cca31d2e9dcb2f98e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 26 Apr 2024 09:39:54 -0700 Subject: [PATCH] [Mosaic GPU] Use a custom TMA descriptor initialization method The one bundled with the default MLIR runtime was convenient, but it is also impractical. It allocates memory (which can deadlock due to NCCL), does a synchronous host-to-device copy and then leaks the descriptor after the kernel... With this change, we use our own runtime function to create all the descriptors. What's more, we pack them all into a single buffer so that a single asynchronous copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer, letting us lean on XLA:GPU for memory management. PiperOrigin-RevId: 628430358 --- jax/experimental/mosaic/gpu/__init__.py | 161 +++++++++++++++++++++--- jax/experimental/mosaic/gpu/utils.py | 13 ++ jaxlib/cuda/BUILD | 1 + jaxlib/mlir/_mlir_libs/BUILD.bazel | 7 ++ jaxlib/mosaic/gpu/BUILD | 13 +- jaxlib/mosaic/gpu/runtime.cc | 95 ++++++++++++++ 6 files changed, 269 insertions(+), 21 deletions(-) create mode 100644 jaxlib/mosaic/gpu/runtime.cc diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d1b3c935e..9e47f2b18 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -1,3 +1,4 @@ +from collections.abc import Callable # Copyright 2024 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,6 +26,7 @@ from typing import Any, Sequence import jax from jax._src import config +from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.lib import xla_client from jax._src.lib import mosaic_gpu as mosaic_gpu_lib @@ -58,6 +60,9 @@ else: PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") +TMA_DESCRIPTOR_BYTES = 128 +TMA_DESCRIPTOR_ALIGNMENT = 64 + c = mgpu.c # This is too common to fully qualify. @@ -97,11 +102,13 @@ mosaic_gpu_p.multiple_results = True @mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types): +def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes): + del module, gmem_scratch_bytes # Unused. return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] -def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): +def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes): + del out_types # Unused. runtime_path = ( pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent.parent.parent / "mosaic" / "gpu" / "libmlir_cuda_runtime.so" @@ -127,11 +134,16 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): ) # pytype: disable=attribute-error op = mlir.custom_call( "mosaic_gpu", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + result_types=[ + *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out), + mlir.aval_to_ir_type( + jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8) + ), + ], operands=args, backend_config=ptr_bytes, ) - return op.results + return op.results[:-1] # Skip the scratch space. 
mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") @@ -227,7 +239,12 @@ OnDeviceProfiler = profiler.OnDeviceProfiler @dataclasses.dataclass() class LaunchContext: launch_op: gpu.LaunchOp + gmem_scratch_ptr: ir.Value profiler: OnDeviceProfiler | None = None + next_scratch_offset: int = 0 + host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( + default_factory=list, init=False + ) tma_descriptors: dict[ tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], ir.Value, @@ -241,6 +258,37 @@ class LaunchContext: else: yield + def _alloc_scratch( + self, + size: int, + alignment: int | None = None, + host_init: Callable[[ir.Value], None] = lambda _: None, + device_init: Callable[[ir.Value], Any] = lambda x: x, + ) -> ir.Value: + """Allocates a GMEM scratch buffer. + + The buffer is initialized on the host and then copied to GMEM before the + kernel launch. + """ + i8 = ir.IntegerType.get_signless(8) + ptr_ty = ir.Type.parse("!llvm.ptr") + if alignment is None: + alignment = size + if self.next_scratch_offset % alignment: + raise NotImplementedError # TODO(apaszke): Pad to match alignment + alloc_base = self.next_scratch_offset + self.next_scratch_offset += size + def host_init_wrapped(host_ptr): + with ir.InsertionPoint(self.launch_op): + host_init( + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + ) + self.host_scratch_init.append(host_init_wrapped) + with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]): + return device_init(llvm.getelementptr( + ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + )) + def _get_tma_desc( self, ref, @@ -265,13 +313,42 @@ class LaunchContext: with ir.InsertionPoint(self.launch_op): for t in gmem_transform: ref = t.apply(ref) - ref_unranked = memref.cast( - ir.UnrankedMemRefType.get(ref_ty.element_type, None), ref - ) - tma_desc = nvgpu.tma_create_descriptor( - tensor_map_ty, - ref_unranked, - [c(s, index) for s in transformed_slice_shape], + ref_ty = ir.MemRefType(ref.type) + + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + def init_tma_desc(host_ptr): + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) + aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) + as_i64 = lambda i: arith.index_cast(i64, i) + alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) + llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... 
+ base_ptr = llvm.getelementptr( + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ) + rank = ref_ty.rank + assert rank * 2 == len(sizes_and_strides) + args = [ + host_ptr, + base_ptr, + c(utils.bytewidth(ref_ty.element_type), i64), + c(rank, i64), + utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), + utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), + c(0 if swizzle is None else swizzle, i64), + utils.pack_array([c(v, i64) for v in transformed_slice_shape]), + ] + func.call([], "mosaic_gpu_init_tma_desc", args) + def cast_tma_desc(device_ptr): + nvvm.prefetch_tensormap(device_ptr) + return builtin.unrealized_conversion_cast( + [tensor_map_ty], [device_ptr] + ) + tma_desc = self._alloc_scratch( + TMA_DESCRIPTOR_BYTES, + alignment=TMA_DESCRIPTOR_ALIGNMENT, + host_init=init_tma_desc, + device_init=cast_tma_desc, ) self.tma_descriptors[tma_desc_key] = tma_desc return tma_desc @@ -378,18 +455,14 @@ class LaunchContext: nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) gpu.barrier() # Groups are supposedly tracked per-thread - def _prefetch_tma_descs(self): - with ir.InsertionPoint(self.launch_op.body.blocks[0]): - with mgpu.once(): - for desc in self.tma_descriptors.values(): - nvgpu.tma_prefetch_descriptor(desc) - +# TODO(apaszke): Inline this @contextlib.contextmanager def _launch( token, grid, block, + gmem_scratch_ptr, smem_buffers, profiler_spec: profiler.ProfilerSpec | None = None, maybe_prof_buffer: ir.Value | None = None, @@ -449,7 +522,7 @@ def _launch( else: prof = None smem_ref_tree = jax.tree.unflatten(smem_buffer_tree, smem_refs) - yield LaunchContext(launch_op, prof), smem_ref_tree + yield LaunchContext(launch_op, gmem_scratch_ptr, prof), smem_ref_tree if prof is not None: prof.finalize(grid=grid) gpu.terminator() @@ -466,6 +539,8 @@ def as_gpu_kernel( ): ptr_ty = ir.Type.parse("!llvm.ptr") token_ty = ir.Type.parse("!gpu.async.token") + i8 = ir.IntegerType.get_signless(8) + i64 = ir.IntegerType.get_signless(64) def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) @@ -489,20 +564,46 @@ def as_gpu_kernel( module = ir.Module.create() with ir.InsertionPoint(module.body): + _declare_runtime_functions() + gmem_scratch_bytes = 0 @func.FuncOp.from_py_func(ptr_ty, ptr_ty) def main(token_ptr, buffers): + nonlocal gmem_scratch_bytes token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] + i = -1 for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) + gmem_scratch_ptr = llvm.LoadOp( + ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i + 1], ptr_ty) + ) in_refs = arg_refs[:len(in_ref_tys)] out_refs = arg_refs[len(in_ref_tys):] prof_buffer = out_refs.pop() if prof_spec is not None else None with _launch( - token, grid, block, smem_scratch_shape, prof_spec, prof_buffer + token, grid, block, gmem_scratch_ptr, smem_scratch_shape, + prof_spec, prof_buffer ) as (launch_ctx, smem_refs): body(launch_ctx, *in_refs, *out_refs, smem_refs) + gmem_scratch_bytes = launch_ctx.next_scratch_offset + # Allocate and initialize the host buffer right before the launch. + # Note that we couldn't do that before, because we had to run the body + # to learn what the scratch contains. 
+ with ir.InsertionPoint(launch_ctx.launch_op): + host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8) + for init_callback in launch_ctx.host_scratch_init: + init_callback(host_scratch_ptr) + func.call( + [], + "mosaic_gpu_memcpy_async_h2d", + [ + gmem_scratch_ptr, + host_scratch_ptr, + c(gmem_scratch_bytes, i64), + token_ptr, + ], + ) main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() module.operation.verify() @@ -523,7 +624,12 @@ def as_gpu_kernel( pass_manager.run(module.operation) def bind(*args): - return mosaic_gpu_p.bind(*args, out_types=out_shape, module=module) + return mosaic_gpu_p.bind( + *args, + out_types=out_shape, + module=module, + gmem_scratch_bytes=gmem_scratch_bytes, + ) if prof_spec is not None: @jax.jit @@ -552,6 +658,21 @@ def as_gpu_kernel( return kernel +def _declare_runtime_functions(): + """Declares the runtime functions that can be used by the generated code.""" + ptr_ty = ir.Type.parse("!llvm.ptr") + i64 = ir.IntegerType.get_signless(64) + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + init_tma_desc_type = ir.FunctionType.get(arg_tys, []) + func.FuncOp( + "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" + ) + memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) + func.FuncOp( + "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" + ) + + def dump_low_level(module): dump_ptx = mosaic_gpu_dump_ptx.value dump_ptxas = mosaic_gpu_dump_ptxas.value diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 103187e23..9e704533a 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -69,6 +69,19 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType): return builtin.unrealized_conversion_cast([memref_ty], [desc]) +def pack_array(values): + if not values: + raise ValueError("Empty array") + elem_ty = values[0].type + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty) + for i, v in enumerate(values): + elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty) + llvm.store(v, elem_ptr) + return arr_ptr + + def get_contiguous_strides(xs): strides_ret = [] stride = 1 diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 99449a5d9..1b976ab32 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -37,6 +37,7 @@ cc_library( defines = ["JAX_GPU_CUDA=1"], visibility = ["//visibility:public"], deps = [ + "@xla//xla/tsl/cuda:cupti", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", ], diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index d947ef6ee..e0d499439 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -219,6 +219,12 @@ pybind_extension( "-fexceptions", "-fno-strict-aliasing", ], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), visibility = ["//third_party/py/jax:__subpackages__"], deps = [ ":jaxlib_mlir_capi_shared_library", @@ -227,6 +233,7 @@ pybind_extension( "//jaxlib/mosaic/gpu:mlir_capi", "@nanobind", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", ], ) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index d0615dd8a..ff40b3165 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -87,9 +87,20 @@ cc_library( 
     alwayslink = True,
 )
 
+cc_library(
+    name = "runtime",
+    srcs = ["runtime.cc"],
+    deps = [
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
+
 cc_binary(
     name = "libmlir_cuda_runtime.so",
-    srcs = ["@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp"],
+    srcs = [
+        "runtime.cc",
+        "@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp",
+    ],
     copts = ["-fvisibility=default"],
     linkopts = select({
         "@xla//xla/python:use_jax_cuda_pip_rpaths": [
diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc
new file mode 100644
index 000000000..d50ec8e63
--- /dev/null
+++ b/jaxlib/mosaic/gpu/runtime.cc
@@ -0,0 +1,95 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include "third_party/gpus/cuda/include/cuda.h"
+
+extern "C" {
+
+void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
+                              int64_t elem_bytewidth, int64_t rank,
+                              int64_t *sizes, int64_t *strides,
+                              int64_t swizzle_bytes, int64_t *window_shape) {
+  CUtensorMapDataType data_type;
+  if (elem_bytewidth == 1) {
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+  } else if (elem_bytewidth == 2) {
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+  } else if (elem_bytewidth == 4) {
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+  } else if (elem_bytewidth == 8) {
+    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+  } else {
+    fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth);
+    abort();
+  }
+  cuuint64_t tma_sizes[5] = {1, 1, 1, 1, 1};
+  for (int i = 0; i < rank; ++i) {
+    tma_sizes[i] = static_cast<cuuint64_t>(sizes[rank - i - 1]);
+  }
+  cuuint64_t tma_strides[5] = {1, 1, 1, 1, 1};
+  if (strides[rank - 1] != 1) {
+    fprintf(stderr, "Minormost stride must be 1, but got %ld\n",
+            strides[rank - 1]);
+    abort();
+  }
+  for (int i = 0; i < rank - 1; ++i) {  // We skip the implicit minor stride.
+    tma_strides[i] =
+        static_cast<cuuint64_t>(strides[rank - i - 2] * elem_bytewidth);
+  }
+  cuuint32_t tma_window_shape[5] = {1, 1, 1, 1, 1};
+  for (int64_t i = 0; i < rank; ++i) {
+    tma_window_shape[i] = static_cast<cuuint32_t>(window_shape[rank - i - 1]);
+  }
+  cuuint32_t element_strides[5] = {1, 1, 1, 1, 1};
+  CUtensorMapSwizzle swizzle;
+  if (swizzle_bytes == 0) {
+    swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
+  } else if (swizzle_bytes == 32) {
+    swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
+  } else if (swizzle_bytes == 64) {
+    swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
+  } else if (swizzle_bytes == 128) {
+    swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
+  } else {
+    fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes);
+    abort();
+  }
+  CUresult result = cuTensorMapEncodeTiled(
+      tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides,
+      tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
+      CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+  if (result != CUDA_SUCCESS) {
+    const char *ptr = nullptr;
+    cuGetErrorString(result, &ptr);
+    fprintf(stderr, "cuTensorMapEncodeTiled failed: %s\n", ptr);
+    abort();
+  }
+}
+
+void mosaic_gpu_memcpy_async_h2d(CUdeviceptr dst, void *src, uint64_t bytes,
+                                 CUstream stream) {
+  CUresult result = cuMemcpyHtoDAsync(dst, src, bytes, stream);
+  if (result != CUDA_SUCCESS) {
+    const char *ptr = nullptr;
+    cuGetErrorString(result, &ptr);
+    fprintf(stderr, "cuMemcpyHtoDAsync failed: %s\n", ptr);
+    abort();
+  }
+}
+
+}
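
Note for readers, appended after the patch body (not part of the commit): the scratch
management added to LaunchContext is essentially a bump allocator over a single GMEM
buffer. Each _alloc_scratch call reserves an aligned byte range and records a host-side
initialization callback; right before the launch, all callbacks are replayed into one
host staging buffer, which mosaic_gpu_memcpy_async_h2d then copies to the device in a
single transfer. Below is a minimal, plain-Python sketch of that control flow. It does
not touch MLIR or CUDA, and the names ScratchAllocator, alloc, finalize and
make_fake_descriptor are illustrative only, not part of the patch.

import dataclasses
from collections.abc import Callable

@dataclasses.dataclass
class ScratchAllocator:
    # Mirrors next_scratch_offset / host_scratch_init in LaunchContext.
    next_offset: int = 0
    host_init: list[Callable[[memoryview], None]] = dataclasses.field(
        default_factory=list
    )

    def alloc(self, size: int, alignment: int | None = None,
              init: Callable[[memoryview], None] = lambda _: None) -> int:
        """Reserves `size` bytes and records how to fill them on the host."""
        alignment = size if alignment is None else alignment
        if self.next_offset % alignment:
            raise NotImplementedError  # The patch does not pad either.
        offset = self.next_offset
        self.next_offset += size
        # Remember how to initialize this slice of the staging buffer later.
        self.host_init.append(lambda buf: init(buf[offset:offset + size]))
        return offset

    def finalize(self) -> bytearray:
        """Builds the host staging buffer; one async H2D copy would follow."""
        staging = bytearray(self.next_offset)
        view = memoryview(staging)
        for callback in self.host_init:
            callback(view)
        return staging

def make_fake_descriptor(fill: int) -> Callable[[memoryview], None]:
    # Stand-in for the role played by init_tma_desc / mosaic_gpu_init_tma_desc.
    def init(view: memoryview) -> None:
        view[:] = bytes([fill]) * len(view)
    return init

allocator = ScratchAllocator()
allocator.alloc(128, alignment=64, init=make_fake_descriptor(0xAA))
allocator.alloc(128, alignment=64, init=make_fake_descriptor(0xBB))
assert len(allocator.finalize()) == 256  # Both descriptors travel in one copy.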
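
A second reference sketch, also not part of the commit: the argument marshalling in
mosaic_gpu_init_tma_desc is easy to misread, since sizes, strides and window_shape
arrive majormost-first as int64 arrays, while cuTensorMapEncodeTiled expects
minormost-first fields with byte strides and the minormost stride left implicit. The
plain-Python rendering below reproduces that reordering for illustration; to_tma_fields
is a hypothetical name and no CUDA call is made.

def to_tma_fields(elem_bytewidth, sizes, strides, window_shape):
    rank = len(sizes)
    assert strides[-1] == 1, "minormost stride must be 1"
    tma_sizes = [1] * 5
    tma_strides = [1] * 5   # Byte strides; the minormost one stays implicit.
    tma_window = [1] * 5
    for i in range(rank):
        tma_sizes[i] = sizes[rank - i - 1]
        tma_window[i] = window_shape[rank - i - 1]
    for i in range(rank - 1):
        tma_strides[i] = strides[rank - i - 2] * elem_bytewidth
    return tma_sizes, tma_strides, tma_window

# A row-major (128, 64) ref of 2-byte elements with a (64, 64) window:
print(to_tma_fields(2, [128, 64], [64, 1], [64, 64]))
# -> ([64, 128, 1, 1, 1], [128, 1, 1, 1, 1], [64, 64, 1, 1, 1])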