Add basic PyTorch integration for Mosaic GPU

We already had most of the relevant pieces and only needed to connect
them together. The most sensitive change is perhaps that I needed to
expose one more symbol from the XLA GPU plugin, but I don't think that
should be a problem.
Adam Paszke 2024-09-06 16:09:58 +00:00
parent 7326db7791
commit 611ad63060
5 changed files with 212 additions and 41 deletions
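
For context, a minimal usage sketch of the new as_torch_gpu_kernel entry point, mirroring the test added in this commit. The import paths for the mosaic_gpu and FragmentedArray aliases are assumptions based on the test file, and the kernel body is illustrative:

    import jax
    import jax.numpy as jnp
    import numpy as np
    import torch
    # Assumed import aliases, following those used in the GPU test file below.
    from jax.experimental.mosaic import gpu as mosaic_gpu
    from jax.experimental.mosaic.gpu import FragmentedArray

    def kernel(ctx, i_gmem, o_gmem, _):
      # Load the input from GMEM, double it, and store the result untiled.
      x = FragmentedArray.load_strided(i_gmem)
      (x + x).store_untiled(o_gmem)

    ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
    doubler = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())

    x = torch.randn((128, 128), dtype=torch.float, device="cuda")
    y = doubler(x)  # Returns a torch.Tensor on the same CUDA device.
    np.testing.assert_allclose(y.cpu(), x.cpu() * 2)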


@@ -27,6 +27,7 @@ import subprocess
import tempfile
import time
from typing import Any, Generic, TypeVar
import weakref
import jax
from jax._src import config
@@ -800,6 +801,21 @@ def _lower_as_gpu_kernel(
  return module, out_shape, unwrap_output_tuple
def _declare_runtime_functions():
  """Declares the runtime functions that can be used by the generated code."""
  ptr_ty = ir.Type.parse("!llvm.ptr")
  i64 = ir.IntegerType.get_signless(64)
  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
  func.FuncOp(
      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
  )
  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
  func.FuncOp(
      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
  )
def as_gpu_kernel(
    body,
    grid: tuple[int, int, int],
@@ -867,16 +883,97 @@ def as_gpu_kernel(
  return kernel
def _declare_runtime_functions():
  """Declares the runtime functions that can be used by the generated code."""
  ptr_ty = ir.Type.parse("!llvm.ptr")
  i64 = ir.IntegerType.get_signless(64)
  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
  func.FuncOp(
      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
  )
  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
  func.FuncOp(
      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
  )
def as_torch_gpu_kernel(
    body,
    grid: tuple[int, int, int],
    block: tuple[int, int, int],
    in_shape,
    out_shape,
    smem_scratch_shape: ShapeTree | Union[ShapeTree],
    prof_spec: profiler.ProfilerSpec | None = None,
    cluster: tuple[int, int, int] = (1, 1, 1),
    module_name: str = "unknown",
):
  try:
    import torch
  except ImportError:
    raise RuntimeError("as_torch_gpu_kernel requires PyTorch")
  torch.cuda.init()  # Make sure CUDA context is set up.

  if isinstance(in_shape, list):
    in_shape = tuple(in_shape)
  elif not isinstance(in_shape, tuple):
    in_shape = (in_shape,)

  flat_out_types, out_treedef = jax.tree.flatten(out_shape)
  expected_arg_treedef = jax.tree.structure(in_shape)

  module, out_shape, unwrap_output_tuple = (
      _lower_as_gpu_kernel(
          body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
          module_name, prof_spec
      )
  )

  # Get our hands on the compilation and unload functions
  try:
    import jax_plugins.xla_cuda12 as cuda_plugin
  except ImportError:
    raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds "
                       "that use backend plugins")
  dll = ctypes.CDLL(cuda_plugin._get_library_path())
  compile_func = dll.MosaicGpuCompile
  compile_func.argtypes = [ctypes.c_void_p]
  compile_func.restype = ctypes.POINTER(ctypes.c_void_p)
  unload_func = dll.MosaicGpuUnload
  unload_func.argtypes = [compile_func.restype]
  unload_func.restype = None

  module_asm = module.operation.get_asm(binary=True, enable_debug_info=True)
  compiled = compile_func(ctypes.c_char_p(module_asm))
  if compiled is None:
    raise RuntimeError("Failed to compile the module")
  ctx, launch_ptr = compiled[0], compiled[1]
  ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx))
  launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr)

  def as_torch_dtype(dtype):
    # torch contains NumPy-compatible dtypes in its top namespace
    return getattr(torch, np.dtype(dtype).name)

  def apply(*args):
    flat_args, arg_treedef = jax.tree.flatten(args)
    if arg_treedef != expected_arg_treedef:
      raise ValueError(
          f"Invalid argument structure: expected {expected_arg_treedef}, got"
          f" {arg_treedef}, ({args=})"
      )

    # Construct a device pointer list like in the XLA calling convention
    buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))()
    i = -1  # Define i in case there are no args
    device = 'cuda'
    for i, arg in enumerate(flat_args):
      buffers[i] = arg.data_ptr()
      device = arg.device
    flat_outs = []
    for i, t in enumerate(flat_out_types, i + 1):
      out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device)
      flat_outs.append(out)
      buffers[i] = out.data_ptr()
    # Allocate another buffer for args of the host-side program. This is sadly
    # the default MLIR calling convention.
    args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)()
    args_ptr[0] = ctx_ptr_ptr
    args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_)
    args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)),
                              ctypes.POINTER(ctypes.c_void_p))
    launch(args_ptr)
    return jax.tree.unflatten(out_treedef, flat_outs)

  # Unload the compiled code when the Python function is destroyed.
  def unload(_):
    unload_func(compiled)
  apply.destructor = weakref.ref(apply, unload)

  return apply


@@ -377,10 +377,40 @@ GetKernelCache() {
  return std::make_pair(&context_cache, &mutex);
}
absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
  mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
  InitContext(&context);
  mlir::ParserConfig parse_config(&context);
  auto module_op =
      mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
  if (!module_op) {
    return absl::InternalError("Failed to parse module");
  }
  auto maybe_engine = Compile(*module_op);
  if (!maybe_engine.ok()) {
    return maybe_engine.status();
  }
  mlir::ExecutionEngine* execution_engine = maybe_engine->get();
  auto main = execution_engine->lookupPacked("_mlir_ciface_main");
  auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
  if (!init || !main) {
    return absl::InternalError("Failed to retrieve kernel function");
  }
  void* module_ptr = nullptr;
  void* kernel_ptr = nullptr;
  void** module_ptr_ptr = &module_ptr;
  void** kernel_ptr_ptr = &kernel_ptr;
  void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
  reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
  return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
                        reinterpret_cast<MosaicHostFunc*>(*main));
}
// Each compiled kernel has a unique init func, and each kernel is used from
// a single HLO module. So it should be safe to not include the CUDA context
// in the key.
absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CachedCompileAndInit(
    CacheKey key, const char* module) {
  auto cache_and_mutex = GetKernelCache();
  auto* cache = cache_and_mutex.first;
@@ -397,33 +427,11 @@ absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
  absl::MutexLock lock(mutex);
  // We released the reader lock, another thread might have initialized it.
  if (cache->find(key) == cache->end()) {
    mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
    InitContext(&context);
    mlir::ParserConfig parse_config(&context);
    auto module_op =
        mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
    if (!module_op) {
      return absl::InternalError("Failed to parse module");
    auto compiled = CompileAndInit(module);
    if (!compiled.ok()) {
      return compiled.status();
    }
    auto maybe_engine = Compile(*module_op);
    if (!maybe_engine.ok()) {
      return maybe_engine.status();
    }
    mlir::ExecutionEngine* execution_engine = maybe_engine->get();
    auto main = execution_engine->lookupPacked("_mlir_ciface_main");
    auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
    if (!init || !main) {
      return absl::InternalError("Failed to retrieve kernel function");
    }
    void* module_ptr = nullptr;
    void* kernel_ptr = nullptr;
    void** module_ptr_ptr = &module_ptr;
    void** kernel_ptr_ptr = &kernel_ptr;
    void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
    reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
    cache->insert_or_assign(
        key, CompiledKernel(std::move(*maybe_engine), kernel_ptr,
                            reinterpret_cast<MosaicHostFunc*>(*main)));
    cache->insert_or_assign(key, std::move(*compiled));
  }
  return cache->at(key).GetHostLaunch();
}
@@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
    abort();
  }
  CacheKey key(hash, reinterpret_cast<uintptr_t>(ctx));
  auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash));
  auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
  if (!ctx_and_kernel.ok()) {
    XlaCustomCallStatusSetFailure(status,
                                  ctx_and_kernel.status().message().data(),
@@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
                                         "CUDA");
}  // namespace
extern "C" {
__attribute__((visibility("default")))
void** MosaicGpuCompile(const char* module) {
auto compiled = CompileAndInit(module);
if (!compiled.ok()) {
return nullptr;
}
auto [ctx, launch] = compiled->GetHostLaunch();
auto tuple_ptr = std::unique_ptr<void*>(new void*[3]);
if (!tuple_ptr) {
return nullptr;
}
tuple_ptr.get()[0] = ctx;
tuple_ptr.get()[1] = reinterpret_cast<void*>(launch);
tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled));
if (!tuple_ptr.get()[2]) {
return nullptr;
}
return tuple_ptr.release();
}
__attribute__((visibility("default")))
void MosaicGpuUnload(void** tuple_ptr) {
delete reinterpret_cast<CompiledKernel*>(tuple_ptr[2]);
delete[] tuple_ptr;
}
} // extern "C"


@@ -64,11 +64,12 @@ py_test(
cc_binary(
    name = "pjrt_c_api_gpu_plugin.so",
    linkopts = [
        "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)",
        "-Wl,--version-script,$(location :gpu_version_script.lds)",
        "-Wl,--no-undefined",
    ],
    linkshared = True,
    deps = [
        ":gpu_version_script.lds",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
        "@xla//xla/service:gpu_plugin",


@@ -0,0 +1,11 @@
VERS_1.0 {
  global:
    extern "C" {
      GetPjrtApi;
      MosaicGpuCompile;
      MosaicGpuUnload;
    };
  local:
    *;
};
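
As a quick sanity check (a sketch, not part of this commit): with the version script in place, the newly exported symbols should be resolvable from the plugin shared object via ctypes, the same way as_torch_gpu_kernel looks them up above.

    import ctypes
    import jax_plugins.xla_cuda12 as cuda_plugin  # Requires a recent plugin-based jaxlib.

    dll = ctypes.CDLL(cuda_plugin._get_library_path())
    # These attribute lookups fail if the version script keeps the symbols local.
    print(dll.GetPjrtApi, dll.MosaicGpuCompile, dll.MosaicGpuUnload)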


@@ -19,6 +19,7 @@ from functools import partial
import itertools
import math
import operator
import unittest
from absl.testing import absltest, parameterized
import jax
@@ -1387,5 +1388,28 @@ class ProfilerTest(TestCase):
    jax.block_until_ready(f(xd))
class TorchTest(TestCase):

  @classmethod
  def setUpClass(cls):
    try:
      import torch
    except ImportError:
      raise unittest.SkipTest("Test requires PyTorch")
    cls.torch = torch

  def test_basic(self):
    def kernel(ctx, i_gmem, o_gmem, _):
      x = mgpu.FragmentedArray.load_strided(i_gmem)
      (x + x).store_untiled(o_gmem)

    ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
    x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda')
    f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())
    y = f(x)
    np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
    del y  # Make sure the destructor runs successfully.


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())