[Mosaic GPU] Only call kernel initializer from inside a custom call

The XLA:GPU custom call design is far from ideal: there is apparently no way to
figure out which CUDA context will be used to run an HLO module before the
custom call is first invoked. As a result, we can't preload the kernel onto the
GPU ahead of time, or we hit invalid handle errors caused by the load and the
launch happening in different CUDA contexts. Instead, the kernel initializer is
now invoked lazily from inside the custom call, where the correct context is
current, and its result is cached.
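
A minimal sketch of that pattern (illustrative names, not the actual jaxlib
symbols): the custom call runs the per-kernel initializer on first use, keyed
by the init function's address, and reuses the cached kernel context on later
calls, so the load happens in the same CUDA context as the launch.

    // Sketch only; the real implementation lives in mosaic_gpu_ext below.
    #include <cstdint>
    #include <mutex>
    #include <unordered_map>

    using InitFunc = void(void***);   // writes the kernel context through its argument
    using LaunchFunc = void(void**);  // launches the kernel given {ctx, stream, buffers}

    void* InitOnFirstCall(InitFunc* init) {
      static std::mutex mu;
      static auto& cache = *new std::unordered_map<std::uintptr_t, void*>;
      std::lock_guard<std::mutex> lock(mu);
      void*& ctx = cache[reinterpret_cast<std::uintptr_t>(init)];
      if (ctx == nullptr) {
        void** ptr_to_ctx = &ctx;
        init(&ptr_to_ctx);  // runs inside the custom call, i.e. in the launch context
      }
      return ctx;
    }

    void KernelCustomCall(void* stream, void** buffers,
                          LaunchFunc* launch, InitFunc* init) {
      void* ctx = InitOnFirstCall(init);
      void* args[3] = {&ctx, &stream, &buffers};
      launch(args);
    }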

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
Author: Adam Paszke (committed by jax authors)
Date: 2024-04-30 07:09:12 -07:00
Parent: 649e0521ff
Commit: 4051ac2a2f
5 changed files with 62 additions and 11 deletions

View File

@@ -118,16 +118,14 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes)
       module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False
   )
   ctx.module_context.add_keepalive(engine)
-  func_ptr = engine.lookup("main")
-  # Run the compile-time initialization.
-  kernel_params_ptr = (ctypes.c_void_p * 1)()
-  engine.invoke("main_init", kernel_params_ptr)
-  kernel_params = kernel_params_ptr[0]
+  launch_func_ptr = ctypes.cast(engine.lookup("main"), ctypes.c_void_p)
+  init_func_ptr = ctypes.cast(engine.lookup("main_init"), ctypes.c_void_p)
+  # Make sure we won't get accidental hits due to address reuse.
+  mosaic_gpu_lib._mosaic_gpu_ext.invalidate_cache(init_func_ptr.value)
   trampoline_args = (ctypes.c_void_p * 2)()
-  trampoline_args[0] = ctypes.cast(func_ptr, ctypes.c_void_p)
-  trampoline_args[1] = ctypes.cast(kernel_params, ctypes.c_void_p)
+  trampoline_args[0] = launch_func_ptr
+  trampoline_args[1] = init_func_ptr
   ctx.module_context.add_keepalive(trampoline_args)
   ptr_bytes = ctypes.cast(trampoline_args, ctypes.c_void_p).value.to_bytes(
       8, byteorder="little"
@@ -340,7 +338,8 @@ class LaunchContext:
         ]
         func.call([], "mosaic_gpu_init_tma_desc", args)
       def cast_tma_desc(device_ptr):
-        nvvm.prefetch_tensormap(device_ptr)
+        # TODO(apaszke): Investigate why prefetching can cause launch failures
+        # nvvm.prefetch_tensormap(device_ptr)
         return builtin.unrealized_conversion_cast(
             [tensor_map_ty], [device_ptr]
         )
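
In the first hunk above, the lowering packs trampoline_args = {launch_func_ptr,
init_func_ptr} and serializes the array's address as 8 little-endian bytes,
which reach the C++ side as the custom call's opaque argument. A rough sketch
of that decode, assuming a 64-bit little-endian host (the actual decode is the
first lines of MosaicKernelCall further down):

    #include <cstdint>
    #include <cstring>

    struct StaticArgs {
      void* launch_func;  // trampoline_args[0]: the compiled "main" entry point
      void* init_func;    // trampoline_args[1]: the compiled "main_init" entry point
    };

    // `opaque` carries the raw bytes of a pointer to the trampoline_args array.
    StaticArgs DecodeOpaque(const char* opaque) {
      std::uintptr_t addr;
      std::memcpy(&addr, opaque, sizeof(addr));  // 8 little-endian bytes
      void** args = reinterpret_cast<void**>(addr);
      return StaticArgs{args[0], args[1]};
    }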

View File

@@ -77,6 +77,7 @@ def measure(f, *args):
   @jax.jit
   def run(*args):
     return _record_event(f(*_record_event(args, start_event)), end_event)
+  jax.block_until_ready(run(*args))  # Warmup.
   results = jax.block_until_ready(run(*args))
   elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed(
       start_event, end_event

View File

@@ -231,6 +231,8 @@ pybind_extension(
         "//jaxlib:kernel_nanobind_helpers",
         "//jaxlib/cuda:cuda_vendor",
         "//jaxlib/mosaic/gpu:mlir_capi",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/synchronization",
         "@nanobind",
         "@xla//xla/service:custom_call_status",
         "@xla//xla/tsl/cuda:cudart",

View File

@@ -1,5 +1,8 @@
 #include <cstdint>
 #include "nanobind/nanobind.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_nanobind_helpers.h"
 #include "jaxlib/mosaic/gpu/integrations/c/passes.h"
@@ -9,13 +12,56 @@ namespace jax::cuda {
 namespace {
 namespace nb = nanobind;
 
+using MosaicInitFunc = void(void***);
 using MosaicHostFunc = void(void**);
 
+std::pair<absl::flat_hash_map<uintptr_t, void*>*, absl::Mutex*>
+GetContextCache() {
+  static absl::Mutex mutex;
+  static auto& context_cache = *new absl::flat_hash_map<uintptr_t, void*>;
+  return std::make_pair(&context_cache, &mutex);
+}
+
+void InvalidateCache(MosaicInitFunc* init) {
+  auto cache = GetContextCache();
+  absl::MutexLock lock(cache.second);
+  // TODO(apaszke): Free all the resources instead of leaking.
+  cache.first->erase(reinterpret_cast<uintptr_t>(init));
+}
+
+// Each compiled kernel has a unique init func, and each kernel is used from
+// a single HLO module. So it should be safe to not include the CUDA context
+// in the key.
+void* InitOnce(MosaicInitFunc* init) {
+  auto cache_and_mutex = GetContextCache();
+  auto* cache = cache_and_mutex.first;
+  auto* mutex = cache_and_mutex.second;
+  uintptr_t key = reinterpret_cast<uintptr_t>(init);
+  {
+    // Fast path uses reader lock (as hash map look-up is relatively slow).
+    absl::ReaderMutexLock lock(mutex);
+    auto it = cache->find(key);
+    if (ABSL_PREDICT_TRUE(it != cache->end())) return it->second;
+  }
+  absl::MutexLock lock(mutex);
+  void*& ctx = (*cache)[key];
+  // We released the reader lock, another thread might have initialized it.
+  if (ctx == nullptr) {
+    void** ptr_to_ctx = &ctx;
+    init(&ptr_to_ctx);
+  }
+  return ctx;
+}
+
 void MosaicKernelCall(void* stream, void** buffers, char* opaque,
                       size_t opaque_len, XlaCustomCallStatus* status) {
   void** static_args = *reinterpret_cast<void***>(opaque);
   MosaicHostFunc* func = reinterpret_cast<MosaicHostFunc*>(static_args[0]);
-  void* ctx = static_args[1];
+  MosaicInitFunc* init = reinterpret_cast<MosaicInitFunc*>(static_args[1]);
+  void* ctx = InitOnce(init);
   void* args[3] = {&ctx, &stream, &buffers};
   func(args);
 }
@@ -53,6 +99,9 @@ NB_MODULE(_mosaic_gpu_ext, m) {
       });
   m.def("_record_event_capsule",
         []() { return EncapsulateFunction(EventRecordCall); });
+  m.def("invalidate_cache", [](uintptr_t init_func_ptr) {
+    return InvalidateCache(reinterpret_cast<MosaicInitFunc*>(init_func_ptr));
+  });
 }
 
 }  // namespace

View File

@@ -274,7 +274,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
   copy_runfiles(
       dst_dir=jaxlib_dir / "mosaic" / "gpu",
       src_files=[
-          "__main__/jaxlib/mosaic/gpu/libmlir_cuda_runtime.so",
+          "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
       ],
   )