From 4051ac2a2fa06c86cbafc05961a7cd56be5303d5 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 30 Apr 2024 07:09:12 -0700 Subject: [PATCH] [Mosaic GPU] Only call kernel initializer from inside a custom call XLA:GPU custom call design is far from ideal, as there's apparently no way to figure out the CUDA context that will be used to run an HLO module before the custom call is first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid handle errors due to the load and launch happening in different CUDA contexts... Also fix up build_wheel.py to match the rename of the runtime lib. PiperOrigin-RevId: 629401858 --- jax/experimental/mosaic/gpu/__init__.py | 17 ++++---- jax/experimental/mosaic/gpu/profiler.py | 1 + jaxlib/mlir/_mlir_libs/BUILD.bazel | 2 + jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc | 51 +++++++++++++++++++++++- jaxlib/tools/build_wheel.py | 2 +- 5 files changed, 62 insertions(+), 11 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 358de1762..102868ec5 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -118,16 +118,14 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes) module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False ) ctx.module_context.add_keepalive(engine) - func_ptr = engine.lookup("main") - - # Run the compile-time initialization. - kernel_params_ptr = (ctypes.c_void_p * 1)() - engine.invoke("main_init", kernel_params_ptr) - kernel_params = kernel_params_ptr[0] + launch_func_ptr = ctypes.cast(engine.lookup("main"), ctypes.c_void_p) + init_func_ptr = ctypes.cast(engine.lookup("main_init"), ctypes.c_void_p) + # Make sure we won't get accidental hits due to address reuse. + mosaic_gpu_lib._mosaic_gpu_ext.invalidate_cache(init_func_ptr.value) trampoline_args = (ctypes.c_void_p * 2)() - trampoline_args[0] = ctypes.cast(func_ptr, ctypes.c_void_p) - trampoline_args[1] = ctypes.cast(kernel_params, ctypes.c_void_p) + trampoline_args[0] = launch_func_ptr + trampoline_args[1] = init_func_ptr ctx.module_context.add_keepalive(trampoline_args) ptr_bytes = ctypes.cast(trampoline_args, ctypes.c_void_p).value.to_bytes( 8, byteorder="little" @@ -340,7 +338,8 @@ class LaunchContext: ] func.call([], "mosaic_gpu_init_tma_desc", args) def cast_tma_desc(device_ptr): - nvvm.prefetch_tensormap(device_ptr) + # TODO(apaszke): Investigate why prefetching can cause launch failures + # nvvm.prefetch_tensormap(device_ptr) return builtin.unrealized_conversion_cast( [tensor_map_ty], [device_ptr] ) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index a9af39a6a..597ea78a6 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -77,6 +77,7 @@ def measure(f, *args): @jax.jit def run(*args): return _record_event(f(*_record_event(args, start_event)), end_event) + jax.block_until_ready(run(*args)) # Warmup. results = jax.block_until_ready(run(*args)) elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed( start_event, end_event diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index e0d499439..b762f5c64 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -231,6 +231,8 @@ pybind_extension( "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", "//jaxlib/mosaic/gpu:mlir_capi", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", "@nanobind", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index bbfa13fe1..8afd1b21a 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -1,5 +1,8 @@ #include + #include "nanobind/nanobind.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" @@ -9,13 +12,56 @@ namespace jax::cuda { namespace { namespace nb = nanobind; +using MosaicInitFunc = void(void***); using MosaicHostFunc = void(void**); +std::pair*, absl::Mutex*> +GetContextCache() { + static absl::Mutex mutex; + static auto& context_cache = *new absl::flat_hash_map; + return std::make_pair(&context_cache, &mutex); +} + +void InvalidateCache(MosaicInitFunc* init) { + auto cache = GetContextCache(); + absl::MutexLock lock(cache.second); + // TODO(apaszke): Free all the resources instead of leaking. + cache.first->erase(reinterpret_cast(init)); +} + +// Each compiled kernel has a unique init func, and each kernel is used from +// a single HLO module. So it should be safe to not include the CUDA context +// in the key. +void* InitOnce(MosaicInitFunc* init) { + auto cache_and_mutex = GetContextCache(); + auto* cache = cache_and_mutex.first; + auto* mutex = cache_and_mutex.second; + + uintptr_t key = reinterpret_cast(init); + + { + // Fast path uses reader lock (as hash map look-up is relatively slow). + absl::ReaderMutexLock lock(mutex); + auto it = cache->find(key); + if (ABSL_PREDICT_TRUE(it != cache->end())) return it->second; + } + + absl::MutexLock lock(mutex); + void*& ctx = (*cache)[key]; + // We released the reader lock, another thread might have initialized it. + if (ctx == nullptr) { + void** ptr_to_ctx = &ctx; + init(&ptr_to_ctx); + } + return ctx; +} + void MosaicKernelCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { void** static_args = *reinterpret_cast(opaque); MosaicHostFunc* func = reinterpret_cast(static_args[0]); - void* ctx = static_args[1]; + MosaicInitFunc* init = reinterpret_cast(static_args[1]); + void* ctx = InitOnce(init); void* args[3] = {&ctx, &stream, &buffers}; func(args); } @@ -53,6 +99,9 @@ NB_MODULE(_mosaic_gpu_ext, m) { }); m.def("_record_event_capsule", []() { return EncapsulateFunction(EventRecordCall); }); + m.def("invalidate_cache", [](uintptr_t init_func_ptr) { + return InvalidateCache(reinterpret_cast(init_func_ptr)); + }); } } // namespace diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index bfe6848a2..05fa972c2 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -274,7 +274,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): copy_runfiles( dst_dir=jaxlib_dir / "mosaic" / "gpu", src_files=[ - "__main__/jaxlib/mosaic/gpu/libmlir_cuda_runtime.so", + "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", ], )