diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index 358de1762..102868ec5 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -118,16 +118,14 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes)
       module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False
   )
   ctx.module_context.add_keepalive(engine)
-  func_ptr = engine.lookup("main")
-
-  # Run the compile-time initialization.
-  kernel_params_ptr = (ctypes.c_void_p * 1)()
-  engine.invoke("main_init", kernel_params_ptr)
-  kernel_params = kernel_params_ptr[0]
+  launch_func_ptr = ctypes.cast(engine.lookup("main"), ctypes.c_void_p)
+  init_func_ptr = ctypes.cast(engine.lookup("main_init"), ctypes.c_void_p)
+  # Make sure we won't get accidental hits due to address reuse.
+  mosaic_gpu_lib._mosaic_gpu_ext.invalidate_cache(init_func_ptr.value)
   trampoline_args = (ctypes.c_void_p * 2)()
-  trampoline_args[0] = ctypes.cast(func_ptr, ctypes.c_void_p)
-  trampoline_args[1] = ctypes.cast(kernel_params, ctypes.c_void_p)
+  trampoline_args[0] = launch_func_ptr
+  trampoline_args[1] = init_func_ptr
   ctx.module_context.add_keepalive(trampoline_args)
   ptr_bytes = ctypes.cast(trampoline_args, ctypes.c_void_p).value.to_bytes(
       8, byteorder="little"
@@ -340,7 +338,8 @@ class LaunchContext:
         ]
         func.call([], "mosaic_gpu_init_tma_desc", args)
       def cast_tma_desc(device_ptr):
-        nvvm.prefetch_tensormap(device_ptr)
+        # TODO(apaszke): Investigate why prefetching can cause launch failures
+        # nvvm.prefetch_tensormap(device_ptr)
         return builtin.unrealized_conversion_cast(
             [tensor_map_ty], [device_ptr]
         )
diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py
index a9af39a6a..597ea78a6 100644
--- a/jax/experimental/mosaic/gpu/profiler.py
+++ b/jax/experimental/mosaic/gpu/profiler.py
@@ -77,6 +77,7 @@ def measure(f, *args):
   @jax.jit
   def run(*args):
     return _record_event(f(*_record_event(args, start_event)), end_event)
+  jax.block_until_ready(run(*args))  # Warmup.
   results = jax.block_until_ready(run(*args))
   elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed(
       start_event, end_event
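For context, a minimal usage sketch of the `measure` helper touched above: with the added warmup call, the timed run no longer includes compilation. The kernel, shapes, and the exact return convention of `measure` are illustrative assumptions, not part of the diff.

```python
import jax.numpy as jnp

from jax.experimental.mosaic.gpu import profiler

def f(x):
  # Stand-in for a Mosaic GPU kernel; any jittable array function works here.
  return x @ x

x = jnp.ones((1024, 1024), jnp.float32)
# The warmup run added in the diff happens inside `measure`, so the value
# reported here should reflect steady-state execution only.
timing = profiler.measure(f, x)
```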
diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel
index e0d499439..b762f5c64 100644
--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel
+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel
@@ -231,6 +231,8 @@ pybind_extension(
         "//jaxlib:kernel_nanobind_helpers",
         "//jaxlib/cuda:cuda_vendor",
         "//jaxlib/mosaic/gpu:mlir_capi",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/synchronization",
         "@nanobind",
         "@xla//xla/service:custom_call_status",
         "@xla//xla/tsl/cuda:cudart",
diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
index bbfa13fe1..8afd1b21a 100644
--- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
+++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
@@ -1,5 +1,8 @@
 #include <cstdint>
+
 #include "nanobind/nanobind.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_nanobind_helpers.h"
 #include "jaxlib/mosaic/gpu/integrations/c/passes.h"
@@ -9,13 +12,56 @@ namespace jax::cuda {
 namespace {
 
 namespace nb = nanobind;
+using MosaicInitFunc = void(void***);
 using MosaicHostFunc = void(void**);
 
+std::pair<absl::flat_hash_map<uintptr_t, void*>*, absl::Mutex*>
+GetContextCache() {
+  static absl::Mutex mutex;
+  static auto& context_cache = *new absl::flat_hash_map<uintptr_t, void*>;
+  return std::make_pair(&context_cache, &mutex);
+}
+
+void InvalidateCache(MosaicInitFunc* init) {
+  auto cache = GetContextCache();
+  absl::MutexLock lock(cache.second);
+  // TODO(apaszke): Free all the resources instead of leaking.
+  cache.first->erase(reinterpret_cast<uintptr_t>(init));
+}
+
+// Each compiled kernel has a unique init func, and each kernel is used from
+// a single HLO module. So it should be safe to not include the CUDA context
+// in the key.
+void* InitOnce(MosaicInitFunc* init) {
+  auto cache_and_mutex = GetContextCache();
+  auto* cache = cache_and_mutex.first;
+  auto* mutex = cache_and_mutex.second;
+
+  uintptr_t key = reinterpret_cast<uintptr_t>(init);
+
+  {
+    // Fast path uses reader lock (as hash map look-up is relatively slow).
+    absl::ReaderMutexLock lock(mutex);
+    auto it = cache->find(key);
+    if (ABSL_PREDICT_TRUE(it != cache->end())) return it->second;
+  }
+
+  absl::MutexLock lock(mutex);
+  void*& ctx = (*cache)[key];
+  // We released the reader lock, another thread might have initialized it.
+  if (ctx == nullptr) {
+    void** ptr_to_ctx = &ctx;
+    init(&ptr_to_ctx);
+  }
+  return ctx;
+}
+
 void MosaicKernelCall(void* stream, void** buffers, char* opaque,
                       size_t opaque_len, XlaCustomCallStatus* status) {
   void** static_args = *reinterpret_cast<void***>(opaque);
   MosaicHostFunc* func = reinterpret_cast<MosaicHostFunc*>(static_args[0]);
-  void* ctx = static_args[1];
+  MosaicInitFunc* init = reinterpret_cast<MosaicInitFunc*>(static_args[1]);
+  void* ctx = InitOnce(init);
   void* args[3] = {&ctx, &stream, &buffers};
   func(args);
 }
@@ -53,6 +99,9 @@ NB_MODULE(_mosaic_gpu_ext, m) {
       });
   m.def("_record_event_capsule",
         []() { return EncapsulateFunction(EventRecordCall); });
+  m.def("invalidate_cache", [](uintptr_t init_func_ptr) {
+    return InvalidateCache(reinterpret_cast<MosaicInitFunc*>(init_func_ptr));
+  });
 }
 
 }  // namespace
diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py
index bfe6848a2..05fa972c2 100644
--- a/jaxlib/tools/build_wheel.py
+++ b/jaxlib/tools/build_wheel.py
@@ -274,7 +274,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
   copy_runfiles(
       dst_dir=jaxlib_dir / "mosaic" / "gpu",
       src_files=[
-          "__main__/jaxlib/mosaic/gpu/libmlir_cuda_runtime.so",
+          "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
      ],
  )
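To make the control flow of the new runtime-initialization path easier to follow, below is an illustrative Python analogue of the C++ InitOnce/InvalidateCache pair above. It is a sketch only: the real implementation is the C++ code in mosaic_gpu_ext.cc (reader-lock fast path, writer lock on miss), and the names and the threading-based locking here are stand-ins.

```python
import threading

_cache: dict[int, object] = {}
_lock = threading.Lock()

def init_once(init_func_ptr: int, run_init) -> object:
  # Fast path: lock-free lookup (the C++ version takes a reader lock here).
  ctx = _cache.get(init_func_ptr)
  if ctx is not None:
    return ctx
  with _lock:
    # Re-check under the lock: another thread may have initialized the
    # context between our fast-path miss and acquiring the lock.
    ctx = _cache.get(init_func_ptr)
    if ctx is None:
      ctx = run_init()  # Runs the kernel's init function exactly once.
      _cache[init_func_ptr] = ctx
    return ctx

def invalidate_cache(init_func_ptr: int) -> None:
  # Mirrors the `invalidate_cache` binding called from the lowering rule:
  # if a previously compiled kernel is freed and a new init function lands
  # at the same address, a stale context must not be returned.
  with _lock:
    _cache.pop(init_func_ptr, None)
```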