Add basic PyTorch integration for Mosaic GPU
We already had most of the relevant pieces in place and only needed to connect them together. The most sensitive change is perhaps that I needed to expose one more symbol from the XLA GPU plugin, but I don't think it should be a problem.
commit 611ad63060 (parent 7326db7791)
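For orientation, here is a minimal end-to-end sketch of the new entry point, modeled on the `TorchTest` added at the bottom of this commit. The import paths and aliases are assumptions (the test file imports them as `mosaic_gpu` and `mgpu`); adjust them to your JAX installation.

```python
# Hedged sketch based on the test added in this commit; not part of the diff.
# The import path below is an assumption -- use whatever module your JAX
# version exposes as `mosaic_gpu`/`mgpu` in the Mosaic GPU tests.
import jax
import jax.numpy as jnp
import torch
import jax.experimental.mosaic.gpu as mgpu  # assumed path
mosaic_gpu = mgpu                            # the test uses both names

def kernel(ctx, i_gmem, o_gmem, _):
  # Load the input from GMEM, double it, and write the result back untiled.
  x = mgpu.FragmentedArray.load_strided(i_gmem)
  (x + x).store_untiled(o_gmem)

ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())

x = torch.randn((128, 128), dtype=torch.float, device="cuda")
y = f(x)  # Returns a torch tensor on the same CUDA device.
torch.testing.assert_close(y, x * 2)
```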
@@ -27,6 +27,7 @@ import subprocess
import tempfile
import time
from typing import Any, Generic, TypeVar
import weakref

import jax
from jax._src import config
@@ -800,6 +801,21 @@ def _lower_as_gpu_kernel(
  return module, out_shape, unwrap_output_tuple


def _declare_runtime_functions():
  """Declares the runtime functions that can be used by the generated code."""
  ptr_ty = ir.Type.parse("!llvm.ptr")
  i64 = ir.IntegerType.get_signless(64)
  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
  func.FuncOp(
      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
  )
  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
  func.FuncOp(
      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
  )


def as_gpu_kernel(
    body,
    grid: tuple[int, int, int],
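As a rough, type-level illustration (not part of the commit), the two declarations above correspond, after LLVM lowering, to C-callable runtime symbols. Expressed as ctypes prototypes (`!llvm.ptr` as `void*`, `i64` as `int64`); the meaning of the individual parameters is not documented in this diff, so treat this purely as a signature sketch.

```python
# Signature-level sketch only; mirrors the arg_tys lists above.
import ctypes

ptr = ctypes.c_void_p
i64 = ctypes.c_int64

# mosaic_gpu_init_tma_desc(ptr, ptr, i64, i64, ptr, ptr, i64, ptr) -> void
MosaicGpuInitTmaDesc = ctypes.CFUNCTYPE(None, ptr, ptr, i64, i64, ptr, ptr, i64, ptr)

# mosaic_gpu_memcpy_async_h2d(ptr, ptr, i64, ptr) -> void
MosaicGpuMemcpyAsyncH2D = ctypes.CFUNCTYPE(None, ptr, ptr, i64, ptr)
```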
@@ -867,16 +883,97 @@ def as_gpu_kernel(
  return kernel


def _declare_runtime_functions():
  """Declares the runtime functions that can be used by the generated code."""
  ptr_ty = ir.Type.parse("!llvm.ptr")
  i64 = ir.IntegerType.get_signless(64)
  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
  func.FuncOp(
      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
  )
  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
  func.FuncOp(
      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
  )


def as_torch_gpu_kernel(
    body,
    grid: tuple[int, int, int],
    block: tuple[int, int, int],
    in_shape,
    out_shape,
    smem_scratch_shape: ShapeTree | Union[ShapeTree],
    prof_spec: profiler.ProfilerSpec | None = None,
    cluster: tuple[int, int, int] = (1, 1, 1),
    module_name: str = "unknown",
):
  try:
    import torch
  except ImportError:
    raise RuntimeError("as_torch_gpu_kernel requires PyTorch")
  torch.cuda.init()  # Make sure CUDA context is set up.

  if isinstance(in_shape, list):
    in_shape = tuple(in_shape)
  elif not isinstance(in_shape, tuple):
    in_shape = (in_shape,)

  flat_out_types, out_treedef = jax.tree.flatten(out_shape)
  expected_arg_treedef = jax.tree.structure(in_shape)

  module, out_shape, unwrap_output_tuple = (
      _lower_as_gpu_kernel(
          body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
          module_name, prof_spec
      )
  )

  # Get our hands on the compilation and unload functions
  try:
    import jax_plugins.xla_cuda12 as cuda_plugin
  except ImportError:
    raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds "
                       "that use backend plugins")
  dll = ctypes.CDLL(cuda_plugin._get_library_path())
  compile_func = dll.MosaicGpuCompile
  compile_func.argtypes = [ctypes.c_void_p]
  compile_func.restype = ctypes.POINTER(ctypes.c_void_p)
  unload_func = dll.MosaicGpuUnload
  unload_func.argtypes = [compile_func.restype]
  unload_func.restype = None

  module_asm = module.operation.get_asm(binary=True, enable_debug_info=True)
  compiled = compile_func(ctypes.c_char_p(module_asm))
  if compiled is None:
    raise RuntimeError("Failed to compile the module")
  ctx, launch_ptr = compiled[0], compiled[1]
  ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx))
  launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr)

  def as_torch_dtype(dtype):
    # torch contains NumPy-compatible dtypes in its top namespace
    return getattr(torch, np.dtype(dtype).name)

  def apply(*args):
    flat_args, arg_treedef = jax.tree.flatten(args)
    if arg_treedef != expected_arg_treedef:
      raise ValueError(
          f"Invalid argument structure: expected {expected_arg_treedef}, got"
          f" {arg_treedef}, ({args=})"
      )

    # Construct a device pointer list like in the XLA calling convention
    buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))()
    i = -1  # Define i in case there are no args
    device = 'cuda'
    for i, arg in enumerate(flat_args):
      buffers[i] = arg.data_ptr()
      device = arg.device
    flat_outs = []
    for i, t in enumerate(flat_out_types, i + 1):
      out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device)
      flat_outs.append(out)
      buffers[i] = out.data_ptr()
    # Allocate another buffer for args of the host-side program. This is sadly
    # the default MLIR calling convention.
    args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)()
    args_ptr[0] = ctx_ptr_ptr
    args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_)
    args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)),
                              ctypes.POINTER(ctypes.c_void_p))
    launch(args_ptr)
    return jax.tree.unflatten(out_treedef, flat_outs)

  # Unload the compiled code when the Python function is destroyed.
  def unload(_):
    unload_func(compiled)
  apply.destructor = weakref.ref(apply, unload)

  return apply

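To make the pointer plumbing in `apply` above easier to follow, here is a standalone ctypes sketch of the three-slot argument array it builds (context, CUDA stream, buffer list), exercised against a dummy callback instead of the real launch function. Everything below is illustrative only and not part of the commit; the stand-in addresses are arbitrary.

```python
# Illustrative mock of the packed argument layout passed to the host function:
# args[0] -> &ctx, args[1] -> &stream, args[2] -> pointer to the buffer array.
import ctypes

def fake_launch(args_void_p):
  # Reinterpret the opaque pointer the way the generated host code would.
  args = ctypes.cast(args_void_p, ctypes.POINTER(ctypes.POINTER(ctypes.c_void_p)))
  ctx = args[0][0]      # user context pointer
  stream = args[1][0]   # CUDA stream handle
  bufs = ctypes.cast(args[2][0], ctypes.POINTER(ctypes.c_void_p))  # device pointers
  print(hex(ctx), hex(stream), hex(bufs[0]), hex(bufs[1]))

launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(fake_launch)

ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(0x1234))  # stand-in for the kernel ctx
stream_ptr = ctypes.pointer(ctypes.c_void_p(0x5678))   # stand-in for the stream
buffers = (ctypes.c_void_p * 2)(0x1000, 0x2000)        # one input, one output
buffers_ptr = ctypes.pointer(buffers)

args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)()
args_ptr[0] = ctx_ptr_ptr
args_ptr[1] = stream_ptr
args_ptr[2] = ctypes.cast(ctypes.pointer(buffers_ptr),
                          ctypes.POINTER(ctypes.c_void_p))
launch(args_ptr)
```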
@@ -377,10 +377,40 @@ GetKernelCache() {
  return std::make_pair(&context_cache, &mutex);
}


absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
  mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
  InitContext(&context);
  mlir::ParserConfig parse_config(&context);
  auto module_op =
      mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
  if (!module_op) {
    return absl::InternalError("Failed to parse module");
  }
  auto maybe_engine = Compile(*module_op);
  if (!maybe_engine.ok()) {
    return maybe_engine.status();
  }
  mlir::ExecutionEngine* execution_engine = maybe_engine->get();
  auto main = execution_engine->lookupPacked("_mlir_ciface_main");
  auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
  if (!init || !main) {
    return absl::InternalError("Failed to retrieve kernel function");
  }
  void* module_ptr = nullptr;
  void* kernel_ptr = nullptr;
  void** module_ptr_ptr = &module_ptr;
  void** kernel_ptr_ptr = &kernel_ptr;
  void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
  reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
  return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
                        reinterpret_cast<MosaicHostFunc*>(*main));
}

// Each compiled kernel has a unique init func, and each kernel is used from
// a single HLO module. So it should be safe to not include the CUDA context
// in the key.
absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CachedCompileAndInit(
    CacheKey key, const char* module) {
  auto cache_and_mutex = GetKernelCache();
  auto* cache = cache_and_mutex.first;
@@ -397,33 +427,11 @@ absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
  absl::MutexLock lock(mutex);
  // We released the reader lock, another thread might have initialized it.
  if (cache->find(key) == cache->end()) {
    mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
    InitContext(&context);
    mlir::ParserConfig parse_config(&context);
    auto module_op =
        mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
    if (!module_op) {
      return absl::InternalError("Failed to parse module");
    auto compiled = CompileAndInit(module);
    if (!compiled.ok()) {
      return compiled.status();
    }
    auto maybe_engine = Compile(*module_op);
    if (!maybe_engine.ok()) {
      return maybe_engine.status();
    }
    mlir::ExecutionEngine* execution_engine = maybe_engine->get();
    auto main = execution_engine->lookupPacked("_mlir_ciface_main");
    auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
    if (!init || !main) {
      return absl::InternalError("Failed to retrieve kernel function");
    }
    void* module_ptr = nullptr;
    void* kernel_ptr = nullptr;
    void** module_ptr_ptr = &module_ptr;
    void** kernel_ptr_ptr = &kernel_ptr;
    void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
    reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
    cache->insert_or_assign(
        key, CompiledKernel(std::move(*maybe_engine), kernel_ptr,
                            reinterpret_cast<MosaicHostFunc*>(*main)));
    cache->insert_or_assign(key, std::move(*compiled));
  }
  return cache->at(key).GetHostLaunch();
}
@@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
    abort();
  }
  CacheKey key(hash, reinterpret_cast<uintptr_t>(ctx));
  auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash));
  auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
  if (!ctx_and_kernel.ok()) {
    XlaCustomCallStatusSetFailure(status,
                                  ctx_and_kernel.status().message().data(),
@@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
                                          "CUDA");

}  // namespace

extern "C" {

__attribute__((visibility("default")))
void** MosaicGpuCompile(const char* module) {
  auto compiled = CompileAndInit(module);
  if (!compiled.ok()) {
    return nullptr;
  }
  auto [ctx, launch] = compiled->GetHostLaunch();
  auto tuple_ptr = std::unique_ptr<void*>(new void*[3]);
  if (!tuple_ptr) {
    return nullptr;
  }
  tuple_ptr.get()[0] = ctx;
  tuple_ptr.get()[1] = reinterpret_cast<void*>(launch);
  tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled));
  if (!tuple_ptr.get()[2]) {
    return nullptr;
  }
  return tuple_ptr.release();
}

__attribute__((visibility("default")))
void MosaicGpuUnload(void** tuple_ptr) {
  delete reinterpret_cast<CompiledKernel*>(tuple_ptr[2]);
  delete[] tuple_ptr;
}

}  // extern "C"

@@ -64,11 +64,12 @@ py_test(
cc_binary(
    name = "pjrt_c_api_gpu_plugin.so",
    linkopts = [
        "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)",
        "-Wl,--version-script,$(location :gpu_version_script.lds)",
        "-Wl,--no-undefined",
    ],
    linkshared = True,
    deps = [
        ":gpu_version_script.lds",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
        "@xla//xla/service:gpu_plugin",
jaxlib/tools/gpu_version_script.lds (new file, 11 lines)
@@ -0,0 +1,11 @@
VERS_1.0 {
  global:
    extern "C" {
      GetPjrtApi;
      MosaicGpuCompile;
      MosaicGpuUnload;
    };

  local:
    *;
};
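A quick way to confirm that the version script above actually exports the two new Mosaic symbols from the installed plugin library (assuming the `jax_plugins.xla_cuda12` wheel, the same one the Python code above loads) is a short ctypes check:

```python
# Sanity check: the symbols listed under `global:` should resolve from the
# plugin shared library. Assumes the jax_plugins.xla_cuda12 plugin is installed.
import ctypes
import jax_plugins.xla_cuda12 as cuda_plugin

dll = ctypes.CDLL(cuda_plugin._get_library_path())
for sym in ("GetPjrtApi", "MosaicGpuCompile", "MosaicGpuUnload"):
  getattr(dll, sym)  # Raises AttributeError if the symbol is not exported.
print("all symbols exported")
```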
@@ -19,6 +19,7 @@ from functools import partial
import itertools
import math
import operator
import unittest

from absl.testing import absltest, parameterized
import jax
@@ -1387,5 +1388,28 @@ class ProfilerTest(TestCase):
    jax.block_until_ready(f(xd))


class TorchTest(TestCase):

  @classmethod
  def setUpClass(cls):
    try:
      import torch
    except ImportError:
      raise unittest.SkipTest("Test requires PyTorch")
    cls.torch = torch

  def test_basic(self):
    def kernel(ctx, i_gmem, o_gmem, _):
      x = mgpu.FragmentedArray.load_strided(i_gmem)
      (x + x).store_untiled(o_gmem)

    ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
    x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda')
    f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())
    y = f(x)
    np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
    del y  # Make sure the destructor runs successfully.


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())