[Mosaic GPU] Only call kernel initializer from inside a custom call
The XLA:GPU custom call design is far from ideal: there is apparently no way to determine the CUDA context that will be used to run an HLO module before the custom call is first invoked. This means we can't preload the kernel onto the GPU ahead of time, or we get invalid handle errors because the load and the launch happen in different CUDA contexts. Instead, the kernel initializer is now called lazily, from inside the custom call itself.

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
parent 649e0521ff
commit 4051ac2a2f
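For orientation, here is a minimal Python sketch of the calling convention this change adopts (all names hypothetical; the real trampoline is the C++ MosaicKernelCall in the diff below): the custom call's opaque payload now carries the kernel's init function instead of a precomputed kernel handle, so the kernel is loaded on the first call, under whatever CUDA context XLA actually uses.

# Hypothetical sketch of the new lazy-init calling convention.
_kernel_contexts = {}  # init function -> kernel context, filled on first call

def mosaic_kernel_call(static_args, stream, buffers):
    launch_fn, init_fn = static_args  # the two pointers packed by the lowering
    if init_fn not in _kernel_contexts:
        # First call: load the kernel under the CUDA context XLA is using now.
        _kernel_contexts[init_fn] = init_fn()
    ctx = _kernel_contexts[init_fn]
    launch_fn(ctx, stream, buffers)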
@@ -118,16 +118,14 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes)
       module, opt_level=3, shared_libs=shared_libs, enable_object_dump=False
   )
   ctx.module_context.add_keepalive(engine)
-  func_ptr = engine.lookup("main")
-
-  # Run the compile-time initialization.
-  kernel_params_ptr = (ctypes.c_void_p * 1)()
-  engine.invoke("main_init", kernel_params_ptr)
-  kernel_params = kernel_params_ptr[0]
+  launch_func_ptr = ctypes.cast(engine.lookup("main"), ctypes.c_void_p)
+  init_func_ptr = ctypes.cast(engine.lookup("main_init"), ctypes.c_void_p)
+  # Make sure we won't get accidental hits due to address reuse.
+  mosaic_gpu_lib._mosaic_gpu_ext.invalidate_cache(init_func_ptr.value)
 
   trampoline_args = (ctypes.c_void_p * 2)()
-  trampoline_args[0] = ctypes.cast(func_ptr, ctypes.c_void_p)
-  trampoline_args[1] = ctypes.cast(kernel_params, ctypes.c_void_p)
+  trampoline_args[0] = launch_func_ptr
+  trampoline_args[1] = init_func_ptr
   ctx.module_context.add_keepalive(trampoline_args)
   ptr_bytes = ctypes.cast(trampoline_args, ctypes.c_void_p).value.to_bytes(
       8, byteorder="little"
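To make the pointer plumbing above concrete, here is a self-contained ctypes demo (dummy function pointers; not the real lowering) of how the two host-function pointers travel through the custom call's opaque bytes: pack them into a C array, then serialize the address of that array as 8 little-endian bytes.

import ctypes

@ctypes.CFUNCTYPE(None)
def _dummy():  # stand-in for the compiled main/main_init functions
    pass

trampoline_args = (ctypes.c_void_p * 2)()
trampoline_args[0] = ctypes.cast(_dummy, ctypes.c_void_p)
trampoline_args[1] = ctypes.cast(_dummy, ctypes.c_void_p)
# The opaque payload is just the array's address; the array itself must be
# kept alive (the lowering uses ctx.module_context.add_keepalive for this).
ptr_bytes = ctypes.cast(trampoline_args, ctypes.c_void_p).value.to_bytes(
    8, byteorder="little"
)
# The C++ side recovers it as: void** static_args = *reinterpret_cast<void***>(opaque);
assert int.from_bytes(ptr_bytes, "little") == ctypes.addressof(trampoline_args)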
@@ -340,7 +338,8 @@ class LaunchContext:
     ]
     func.call([], "mosaic_gpu_init_tma_desc", args)
     def cast_tma_desc(device_ptr):
-      nvvm.prefetch_tensormap(device_ptr)
+      # TODO(apaszke): Investigate why prefetching can cause launch failures
+      # nvvm.prefetch_tensormap(device_ptr)
       return builtin.unrealized_conversion_cast(
           [tensor_map_ty], [device_ptr]
       )
@@ -77,6 +77,7 @@ def measure(f, *args):
   @jax.jit
   def run(*args):
     return _record_event(f(*_record_event(args, start_event)), end_event)
+  jax.block_until_ready(run(*args))  # Warmup.
   results = jax.block_until_ready(run(*args))
   elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed(
       start_event, end_event
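The warmup run matters more after this change: the first invocation now pays for kernel initialization in addition to compilation, so it must not be the timed one. A generic JAX sketch of the pattern, using wall-clock timing as a stand-in for the CUDA events used in the profiler:

import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: x @ x.T)
x = jnp.ones((1024, 1024))

jax.block_until_ready(f(x))  # Warmup: compilation + one-time kernel init.
t0 = time.perf_counter()
jax.block_until_ready(f(x))  # Only steady-state execution is timed.
elapsed_ms = (time.perf_counter() - t0) * 1000
print(f"{elapsed_ms:.3f} ms")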
@@ -231,6 +231,8 @@ pybind_extension(
         "//jaxlib:kernel_nanobind_helpers",
         "//jaxlib/cuda:cuda_vendor",
         "//jaxlib/mosaic/gpu:mlir_capi",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/synchronization",
         "@nanobind",
         "@xla//xla/service:custom_call_status",
         "@xla//xla/tsl/cuda:cudart",
@@ -1,5 +1,8 @@
+#include <cstdint>
 
 #include "nanobind/nanobind.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/synchronization/mutex.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_nanobind_helpers.h"
 #include "jaxlib/mosaic/gpu/integrations/c/passes.h"
@@ -9,13 +12,56 @@ namespace jax::cuda {
 namespace {
 
 namespace nb = nanobind;
+using MosaicInitFunc = void(void***);
 using MosaicHostFunc = void(void**);
 
+std::pair<absl::flat_hash_map<uintptr_t, void*>*, absl::Mutex*>
+GetContextCache() {
+  static absl::Mutex mutex;
+  static auto& context_cache = *new absl::flat_hash_map<uintptr_t, void*>;
+  return std::make_pair(&context_cache, &mutex);
+}
+
+void InvalidateCache(MosaicInitFunc* init) {
+  auto cache = GetContextCache();
+  absl::MutexLock lock(cache.second);
+  // TODO(apaszke): Free all the resources instead of leaking.
+  cache.first->erase(reinterpret_cast<uintptr_t>(init));
+}
+
+// Each compiled kernel has a unique init func, and each kernel is used from
+// a single HLO module. So it should be safe to not include the CUDA context
+// in the key.
+void* InitOnce(MosaicInitFunc* init) {
+  auto cache_and_mutex = GetContextCache();
+  auto* cache = cache_and_mutex.first;
+  auto* mutex = cache_and_mutex.second;
+
+  uintptr_t key = reinterpret_cast<uintptr_t>(init);
+
+  {
+    // Fast path uses reader lock (as hash map look-up is relatively slow).
+    absl::ReaderMutexLock lock(mutex);
+    auto it = cache->find(key);
+    if (ABSL_PREDICT_TRUE(it != cache->end())) return it->second;
+  }
+
+  absl::MutexLock lock(mutex);
+  void*& ctx = (*cache)[key];
+  // We released the reader lock, another thread might have initialized it.
+  if (ctx == nullptr) {
+    void** ptr_to_ctx = &ctx;
+    init(&ptr_to_ctx);
+  }
+  return ctx;
+}
+
 void MosaicKernelCall(void* stream, void** buffers, char* opaque,
                       size_t opaque_len, XlaCustomCallStatus* status) {
   void** static_args = *reinterpret_cast<void***>(opaque);
   MosaicHostFunc* func = reinterpret_cast<MosaicHostFunc*>(static_args[0]);
-  void* ctx = static_args[1];
+  MosaicInitFunc* init = reinterpret_cast<MosaicInitFunc*>(static_args[1]);
+  void* ctx = InitOnce(init);
   void* args[3] = {&ctx, &stream, &buffers};
   func(args);
 }
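InitOnce above is classic double-checked locking: a cheap reader-locked lookup first, then a re-check of the entry under the writer lock, because another thread may have initialized it after the reader lock was released. A rough Python analogue of the same pattern (threading.Lock standing in for absl::Mutex; plain dict reads are safe under the GIL, so the fast path needs no lock):

import threading

_cache: dict[int, object] = {}
_lock = threading.Lock()

def init_once(key: int, init):
    ctx = _cache.get(key)  # fast path: no lock for the common case
    if ctx is not None:
        return ctx
    with _lock:
        ctx = _cache.get(key)  # re-check: someone may have beaten us to it
        if ctx is None:
            ctx = _cache[key] = init()
        return ctx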
@@ -53,6 +99,9 @@ NB_MODULE(_mosaic_gpu_ext, m) {
   });
   m.def("_record_event_capsule",
         []() { return EncapsulateFunction(EventRecordCall); });
+  m.def("invalidate_cache", [](uintptr_t init_func_ptr) {
+    return InvalidateCache(reinterpret_cast<MosaicInitFunc*>(init_func_ptr));
+  });
 }
 
 }  // namespace
@@ -274,7 +274,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
   copy_runfiles(
       dst_dir=jaxlib_dir / "mosaic" / "gpu",
       src_files=[
-          "__main__/jaxlib/mosaic/gpu/libmlir_cuda_runtime.so",
+          "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so",
       ],
   )
 