Merge pull request #23653 from apaszke:torchsaic

PiperOrigin-RevId: 675967844
commit 4e6f690724
Author: jax authors
Date: 2024-09-18 06:35:15 -07:00
5 changed files with 212 additions and 41 deletions


@@ -27,6 +27,7 @@ import subprocess
import tempfile
import time
from typing import Any, Generic, TypeVar
import weakref
import jax
from jax._src import config
@@ -800,6 +801,21 @@ def _lower_as_gpu_kernel(
  return module, out_shape, unwrap_output_tuple


def _declare_runtime_functions():
  """Declares the runtime functions that can be used by the generated code."""
  ptr_ty = ir.Type.parse("!llvm.ptr")
  i64 = ir.IntegerType.get_signless(64)
  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
  func.FuncOp(
      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
  )
  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
  func.FuncOp(
      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
  )


def as_gpu_kernel(
    body,
    grid: tuple[int, int, int],
@@ -867,16 +883,97 @@ def as_gpu_kernel(
  return kernel


def _declare_runtime_functions():
  """Declares the runtime functions that can be used by the generated code."""
  ptr_ty = ir.Type.parse("!llvm.ptr")
  i64 = ir.IntegerType.get_signless(64)
  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
  func.FuncOp(
      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
  )
  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
  func.FuncOp(
      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
  )

def as_torch_gpu_kernel(
    body,
    grid: tuple[int, int, int],
    block: tuple[int, int, int],
    in_shape,
    out_shape,
    smem_scratch_shape: ShapeTree | Union[ShapeTree],
    prof_spec: profiler.ProfilerSpec | None = None,
    cluster: tuple[int, int, int] = (1, 1, 1),
    module_name: str = "unknown",
):
  try:
    import torch
  except ImportError:
    raise RuntimeError("as_torch_gpu_kernel requires PyTorch")
  torch.cuda.init()  # Make sure CUDA context is set up.

  if isinstance(in_shape, list):
    in_shape = tuple(in_shape)
  elif not isinstance(in_shape, tuple):
    in_shape = (in_shape,)

  flat_out_types, out_treedef = jax.tree.flatten(out_shape)
  expected_arg_treedef = jax.tree.structure(in_shape)

  module, out_shape, unwrap_output_tuple = (
      _lower_as_gpu_kernel(
          body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
          module_name, prof_spec
      )
  )

  # Get our hands on the compilation and unload functions
  try:
    import jax_plugins.xla_cuda12 as cuda_plugin
  except ImportError:
    raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds "
                       "that use backend plugins")
  dll = ctypes.CDLL(cuda_plugin._get_library_path())
  compile_func = dll.MosaicGpuCompile
  compile_func.argtypes = [ctypes.c_void_p]
  compile_func.restype = ctypes.POINTER(ctypes.c_void_p)
  unload_func = dll.MosaicGpuUnload
  unload_func.argtypes = [compile_func.restype]
  unload_func.restype = None

  module_asm = module.operation.get_asm(binary=True, enable_debug_info=True)
  compiled = compile_func(ctypes.c_char_p(module_asm))
  if compiled is None:
    raise RuntimeError("Failed to compile the module")
  ctx, launch_ptr = compiled[0], compiled[1]
  ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx))
  launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr)

  def as_torch_dtype(dtype):
    # torch contains NumPy-compatible dtypes in its top namespace
    return getattr(torch, np.dtype(dtype).name)

  def apply(*args):
    flat_args, arg_treedef = jax.tree.flatten(args)
    if arg_treedef != expected_arg_treedef:
      raise ValueError(
          f"Invalid argument structure: expected {expected_arg_treedef}, got"
          f" {arg_treedef}, ({args=})"
      )

    # Construct a device pointer list like in the XLA calling convention
    buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))()
    i = -1  # Define i in case there are no args
    device = 'cuda'
    for i, arg in enumerate(flat_args):
      buffers[i] = arg.data_ptr()
      device = arg.device
    flat_outs = []
    for i, t in enumerate(flat_out_types, i + 1):
      out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device)
      flat_outs.append(out)
      buffers[i] = out.data_ptr()
    # Allocate another buffer for args of the host-side program. This is sadly
    # the default MLIR calling convention.
    args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)()
    args_ptr[0] = ctx_ptr_ptr
    args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_)
    args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)),
                              ctypes.POINTER(ctypes.c_void_p))
    launch(args_ptr)
    return jax.tree.unflatten(out_treedef, flat_outs)

  # Unload the compiled code when the Python function is destroyed.
  def unload(_):
    unload_func(compiled)
  apply.destructor = weakref.ref(apply, unload)

  return apply
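
For orientation, here is how the new entry point is used, modeled directly on the test added at the end of this change. The import aliases and the kernel body are the editor's assumptions (the test file binds the Mosaic GPU package to the names mgpu and mosaic_gpu); treat this as a sketch, not as part of the commit:

# --- Usage sketch (editor's example, not part of this commit) ---
import jax
import jax.numpy as jnp
import numpy as np
import torch
# Assumed aliases: both names refer to the Mosaic GPU Python package, matching
# how the test file below refers to it.
import jax.experimental.mosaic.gpu as mgpu
mosaic_gpu = mgpu

def kernel(ctx, i_gmem, o_gmem, _):
  # Load the input from GMEM, double it, and store the result untiled.
  x = mgpu.FragmentedArray.load_strided(i_gmem)
  (x + x).store_untiled(o_gmem)

ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
# grid=(1, 1, 1), block=(128, 1, 1), no SMEM scratch.
f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())

x = torch.randn((128, 128), dtype=torch.float, device="cuda")
y = f(x)  # Launches on the default CUDA stream of x's device.
np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
# --- end of sketch ---

The returned callable keeps the compiled kernel alive; when it is garbage collected, the weakref callback above invokes MosaicGpuUnload to free the native resources.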


@@ -377,10 +377,40 @@ GetKernelCache() {
return std::make_pair(&context_cache, &mutex);
}
absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
  mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
  InitContext(&context);
  mlir::ParserConfig parse_config(&context);
  auto module_op =
      mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
  if (!module_op) {
    return absl::InternalError("Failed to parse module");
  }
  auto maybe_engine = Compile(*module_op);
  if (!maybe_engine.ok()) {
    return maybe_engine.status();
  }
  mlir::ExecutionEngine* execution_engine = maybe_engine->get();
  auto main = execution_engine->lookupPacked("_mlir_ciface_main");
  auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
  if (!init || !main) {
    return absl::InternalError("Failed to retrieve kernel function");
  }
  void* module_ptr = nullptr;
  void* kernel_ptr = nullptr;
  void** module_ptr_ptr = &module_ptr;
  void** kernel_ptr_ptr = &kernel_ptr;
  void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
  reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
  return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
                        reinterpret_cast<MosaicHostFunc*>(*main));
}
// Each compiled kernel has a unique init func, and each kernel is used from
// a single HLO module. So it should be safe to not include the CUDA context
// in the key.
absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CachedCompileAndInit(
CacheKey key, const char* module) {
auto cache_and_mutex = GetKernelCache();
auto* cache = cache_and_mutex.first;
@@ -397,33 +427,11 @@ absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
absl::MutexLock lock(mutex);
// We released the reader lock, another thread might have initialized it.
if (cache->find(key) == cache->end()) {
mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
InitContext(&context);
mlir::ParserConfig parse_config(&context);
auto module_op =
mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
if (!module_op) {
return absl::InternalError("Failed to parse module");
auto compiled = CompileAndInit(module);
if (!compiled.ok()) {
return compiled.status();
}
auto maybe_engine = Compile(*module_op);
if (!maybe_engine.ok()) {
return maybe_engine.status();
}
mlir::ExecutionEngine* execution_engine = maybe_engine->get();
auto main = execution_engine->lookupPacked("_mlir_ciface_main");
auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
if (!init || !main) {
return absl::InternalError("Failed to retrieve kernel function");
}
void* module_ptr = nullptr;
void* kernel_ptr = nullptr;
void** module_ptr_ptr = &module_ptr;
void** kernel_ptr_ptr = &kernel_ptr;
void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
cache->insert_or_assign(
key, CompiledKernel(std::move(*maybe_engine), kernel_ptr,
reinterpret_cast<MosaicHostFunc*>(*main)));
cache->insert_or_assign(key, std::move(*compiled));
}
return cache->at(key).GetHostLaunch();
}
@@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
abort();
}
CacheKey key(hash, reinterpret_cast<uintptr_t>(ctx));
auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash));
auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
if (!ctx_and_kernel.ok()) {
XlaCustomCallStatusSetFailure(status,
ctx_and_kernel.status().message().data(),
@@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
"CUDA");
} // namespace
extern "C" {
__attribute__((visibility("default")))
void** MosaicGpuCompile(const char* module) {
auto compiled = CompileAndInit(module);
if (!compiled.ok()) {
return nullptr;
}
auto [ctx, launch] = compiled->GetHostLaunch();
auto tuple_ptr = std::unique_ptr<void*>(new void*[3]);
if (!tuple_ptr) {
return nullptr;
}
tuple_ptr.get()[0] = ctx;
tuple_ptr.get()[1] = reinterpret_cast<void*>(launch);
tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled));
if (!tuple_ptr.get()[2]) {
return nullptr;
}
return tuple_ptr.release();
}
__attribute__((visibility("default")))
void MosaicGpuUnload(void** tuple_ptr) {
delete reinterpret_cast<CompiledKernel*>(tuple_ptr[2]);
delete[] tuple_ptr;
}
} // extern "C"
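
For reference, the contract these exports establish with the Python caller above: MosaicGpuCompile parses and compiles the serialized module, runs its init function, and returns a heap-allocated void*[3] holding the init context, the host launch function, and an owning CompiledKernel pointer; MosaicGpuUnload releases all of it. A minimal ctypes sketch of a consumer follows; the helper names are hypothetical and the plugin path is left to the caller:

# --- ctypes consumer sketch (editor's example, not part of this commit) ---
import ctypes

def load_mosaic_exports(plugin_path: str):
  # The caller supplies the plugin path; the Python change above obtains it
  # from jax_plugins.xla_cuda12._get_library_path().
  dll = ctypes.CDLL(plugin_path)
  compile_func = dll.MosaicGpuCompile
  compile_func.argtypes = [ctypes.c_void_p]               # serialized MLIR module
  compile_func.restype = ctypes.POINTER(ctypes.c_void_p)  # void*[3], or NULL on failure
  unload_func = dll.MosaicGpuUnload
  unload_func.argtypes = [compile_func.restype]
  unload_func.restype = None
  return compile_func, unload_func

def compile_launch_unload(plugin_path: str, module_asm: bytes):
  compile_func, unload_func = load_mosaic_exports(plugin_path)
  compiled = compile_func(ctypes.c_char_p(module_asm))
  if not compiled:  # NULL pointer means compilation or init failed
    raise RuntimeError("MosaicGpuCompile failed")
  ctx, launch_ptr = compiled[0], compiled[1]  # compiled[2] owns the C++ CompiledKernel
  launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr)
  # ... build the (ctx, stream, buffers) argument array and call `launch` ...
  unload_func(compiled)  # must run exactly once, after the last launch
# --- end of sketch ---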


@@ -64,11 +64,12 @@ py_test(
cc_binary(
name = "pjrt_c_api_gpu_plugin.so",
linkopts = [
"-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)",
"-Wl,--version-script,$(location :gpu_version_script.lds)",
"-Wl,--no-undefined",
],
linkshared = True,
deps = [
":gpu_version_script.lds",
"@xla//xla/pjrt/c:pjrt_c_api_gpu",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
"@xla//xla/service:gpu_plugin",


@@ -0,0 +1,11 @@
VERS_1.0 {
  global:
    extern "C" {
      GetPjrtApi;
      MosaicGpuCompile;
      MosaicGpuUnload;
    };
  local:
    *;
};
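
The script keeps the plugin's symbol surface minimal: GetPjrtApi and the two Mosaic entry points stay global, and the local: *; clause hides everything else. A quick sanity check from Python, with the library path left as a caller-supplied placeholder:

# --- symbol visibility check (editor's example, not part of this commit) ---
import ctypes

def check_exports(plugin_path: str) -> None:
  dll = ctypes.CDLL(plugin_path)
  # Attribute access goes through dlsym, so it succeeds only for symbols the
  # version script leaves global; hidden symbols raise AttributeError.
  for name in ("GetPjrtApi", "MosaicGpuCompile", "MosaicGpuUnload"):
    getattr(dll, name)
# --- end of sketch ---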


@@ -19,6 +19,7 @@ from functools import partial
import itertools
import math
import operator
import unittest
from absl.testing import absltest, parameterized
import jax
@@ -1389,5 +1390,28 @@ class ProfilerTest(TestCase):
    jax.block_until_ready(f(xd))


class TorchTest(TestCase):

  @classmethod
  def setUpClass(cls):
    try:
      import torch
    except ImportError:
      raise unittest.SkipTest("Test requires PyTorch")
    cls.torch = torch

  def test_basic(self):
    def kernel(ctx, i_gmem, o_gmem, _):
      x = mgpu.FragmentedArray.load_strided(i_gmem)
      (x + x).store_untiled(o_gmem)

    ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
    x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda')
    f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())
    y = f(x)
    np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
    del y  # Make sure the destructor runs successfully.


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())