commit 4e6f690724

Merge pull request #23653 from apaszke:torchsaic

PiperOrigin-RevId: 675967844
@@ -27,6 +27,7 @@ import subprocess
 import tempfile
 import time
 from typing import Any, Generic, TypeVar
+import weakref
 
 import jax
 from jax._src import config
@@ -800,6 +801,21 @@ def _lower_as_gpu_kernel(
   return module, out_shape, unwrap_output_tuple
 
 
+def _declare_runtime_functions():
+  """Declares the runtime functions that can be used by the generated code."""
+  ptr_ty = ir.Type.parse("!llvm.ptr")
+  i64 = ir.IntegerType.get_signless(64)
+  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
+  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
+  func.FuncOp(
+      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
+  )
+  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
+  func.FuncOp(
+      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
+  )
+
+
 def as_gpu_kernel(
     body,
     grid: tuple[int, int, int],
@@ -867,16 +883,97 @@ def as_gpu_kernel(
   return kernel
 
 
-def _declare_runtime_functions():
-  """Declares the runtime functions that can be used by the generated code."""
-  ptr_ty = ir.Type.parse("!llvm.ptr")
-  i64 = ir.IntegerType.get_signless(64)
-  arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
-  init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
-  func.FuncOp(
-      "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
-  )
-  memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
-  func.FuncOp(
-      "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
-  )
+def as_torch_gpu_kernel(
+    body,
+    grid: tuple[int, int, int],
+    block: tuple[int, int, int],
+    in_shape,
+    out_shape,
+    smem_scratch_shape: ShapeTree | Union[ShapeTree],
+    prof_spec: profiler.ProfilerSpec | None = None,
+    cluster: tuple[int, int, int] = (1, 1, 1),
+    module_name: str = "unknown",
+):
+  try:
+    import torch
+  except ImportError:
+    raise RuntimeError("as_torch_gpu_kernel requires PyTorch")
+  torch.cuda.init()  # Make sure CUDA context is set up.
+
+  if isinstance(in_shape, list):
+    in_shape = tuple(in_shape)
+  elif not isinstance(in_shape, tuple):
+    in_shape = (in_shape,)
+
+  flat_out_types, out_treedef = jax.tree.flatten(out_shape)
+  expected_arg_treedef = jax.tree.structure(in_shape)
+
+  module, out_shape, unwrap_output_tuple = (
+      _lower_as_gpu_kernel(
+          body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
+          module_name, prof_spec
+      )
+  )
+
+  # Get our hands on the compilation and unload functions
+  try:
+    import jax_plugins.xla_cuda12 as cuda_plugin
+  except ImportError:
+    raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds "
+                       "that use backend plugins")
+  dll = ctypes.CDLL(cuda_plugin._get_library_path())
+  compile_func = dll.MosaicGpuCompile
+  compile_func.argtypes = [ctypes.c_void_p]
+  compile_func.restype = ctypes.POINTER(ctypes.c_void_p)
+  unload_func = dll.MosaicGpuUnload
+  unload_func.argtypes = [compile_func.restype]
+  unload_func.restype = None
+
+  module_asm = module.operation.get_asm(binary=True, enable_debug_info=True)
+  compiled = compile_func(ctypes.c_char_p(module_asm))
+  if compiled is None:
+    raise RuntimeError("Failed to compile the module")
+  ctx, launch_ptr = compiled[0], compiled[1]
+  ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx))
+  launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr)
+
+  def as_torch_dtype(dtype):
+    # torch contains NumPy-compatible dtypes in its top namespace
+    return getattr(torch, np.dtype(dtype).name)
+
+  def apply(*args):
+    flat_args, arg_treedef = jax.tree.flatten(args)
+    if arg_treedef != expected_arg_treedef:
+      raise ValueError(
+          f"Invalid argument structure: expected {expected_arg_treedef}, got"
+          f" {arg_treedef}, ({args=})"
+      )
+
+    # Construct a device pointer list like in the XLA calling convention
+    buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))()
+    i = -1  # Define i in case there are no args
+    device = 'cuda'
+    for i, arg in enumerate(flat_args):
+      buffers[i] = arg.data_ptr()
+      device = arg.device
+    flat_outs = []
+    for i, t in enumerate(flat_out_types, i + 1):
+      out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device)
+      flat_outs.append(out)
+      buffers[i] = out.data_ptr()
+    # Allocate another buffer for args of the host-side program. This is sadly
+    # the default MLIR calling convention.
+    args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)()
+    args_ptr[0] = ctx_ptr_ptr
+    args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_)
+    args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)),
+                              ctypes.POINTER(ctypes.c_void_p))
+    launch(args_ptr)
+    return jax.tree.unflatten(out_treedef, flat_outs)
+
+  # Unload the compiled code when the Python function is destroyed.
+  def unload(_):
+    unload_func(compiled)
+  apply.destructor = weakref.ref(apply, unload)
+
+  return apply
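Usage sketch for the new as_torch_gpu_kernel entry point (not part of the diff). The kernel body, shapes, and dtypes are lifted from the TorchTest added at the bottom of this change; the import alias for the Mosaic GPU module is an assumption and may differ between jaxlib versions. It assumes a CUDA GPU, PyTorch, and a jaxlib built with backend plugins.

import jax
import jax.numpy as jnp
import numpy as np
import torch

import jax.experimental.mosaic.gpu as mgpu  # assumed import path

def kernel(ctx, i_gmem, o_gmem, _):
  # Load a 128x128 tile from GMEM, double it, and store it back.
  x = mgpu.FragmentedArray.load_strided(i_gmem)
  (x + x).store_untiled(o_gmem)

ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())

x = torch.randn((128, 128), dtype=torch.float, device="cuda")
y = f(x)  # outputs are torch tensors allocated on the input's device
np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
del f  # dropping the last reference should trigger the weakref destructor and unload the kernel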
@@ -377,10 +377,40 @@ GetKernelCache() {
   return std::make_pair(&context_cache, &mutex);
 }
 
+absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
+  mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
+  InitContext(&context);
+  mlir::ParserConfig parse_config(&context);
+  auto module_op =
+      mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
+  if (!module_op) {
+    return absl::InternalError("Failed to parse module");
+  }
+  auto maybe_engine = Compile(*module_op);
+  if (!maybe_engine.ok()) {
+    return maybe_engine.status();
+  }
+  mlir::ExecutionEngine* execution_engine = maybe_engine->get();
+  auto main = execution_engine->lookupPacked("_mlir_ciface_main");
+  auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
+  if (!init || !main) {
+    return absl::InternalError("Failed to retrieve kernel function");
+  }
+  void* module_ptr = nullptr;
+  void* kernel_ptr = nullptr;
+  void** module_ptr_ptr = &module_ptr;
+  void** kernel_ptr_ptr = &kernel_ptr;
+  void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
+  reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
+  return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
+                        reinterpret_cast<MosaicHostFunc*>(*main));
+}
+
 // Each compiled kernel has a unique init func, and each kernel is used from
 // a single HLO module. So it should be safe to not include the CUDA context
 // in the key.
-absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
+absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CachedCompileAndInit(
     CacheKey key, const char* module) {
   auto cache_and_mutex = GetKernelCache();
   auto* cache = cache_and_mutex.first;
@@ -397,33 +427,11 @@ absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
   absl::MutexLock lock(mutex);
   // We released the reader lock, another thread might have initialized it.
   if (cache->find(key) == cache->end()) {
-    mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
-    InitContext(&context);
-    mlir::ParserConfig parse_config(&context);
-    auto module_op =
-        mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
-    if (!module_op) {
-      return absl::InternalError("Failed to parse module");
+    auto compiled = CompileAndInit(module);
+    if (!compiled.ok()) {
+      return compiled.status();
     }
-    auto maybe_engine = Compile(*module_op);
-    if (!maybe_engine.ok()) {
-      return maybe_engine.status();
-    }
-    mlir::ExecutionEngine* execution_engine = maybe_engine->get();
-    auto main = execution_engine->lookupPacked("_mlir_ciface_main");
-    auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
-    if (!init || !main) {
-      return absl::InternalError("Failed to retrieve kernel function");
-    }
-    void* module_ptr = nullptr;
-    void* kernel_ptr = nullptr;
-    void** module_ptr_ptr = &module_ptr;
-    void** kernel_ptr_ptr = &kernel_ptr;
-    void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
-    reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
-    cache->insert_or_assign(
-        key, CompiledKernel(std::move(*maybe_engine), kernel_ptr,
-                            reinterpret_cast<MosaicHostFunc*>(*main)));
+    cache->insert_or_assign(key, std::move(*compiled));
   }
   return cache->at(key).GetHostLaunch();
 }
@@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
     abort();
   }
   CacheKey key(hash, reinterpret_cast<uintptr_t>(ctx));
-  auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash));
+  auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash));
   if (!ctx_and_kernel.ok()) {
     XlaCustomCallStatusSetFailure(status,
                                   ctx_and_kernel.status().message().data(),
@@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
                                          "CUDA");
 
 }  // namespace
+
+extern "C" {
+
+__attribute__((visibility("default")))
+void** MosaicGpuCompile(const char* module) {
+  auto compiled = CompileAndInit(module);
+  if (!compiled.ok()) {
+    return nullptr;
+  }
+  auto [ctx, launch] = compiled->GetHostLaunch();
+  auto tuple_ptr = std::unique_ptr<void*>(new void*[3]);
+  if (!tuple_ptr) {
+    return nullptr;
+  }
+  tuple_ptr.get()[0] = ctx;
+  tuple_ptr.get()[1] = reinterpret_cast<void*>(launch);
+  tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled));
+  if (!tuple_ptr.get()[2]) {
+    return nullptr;
+  }
+  return tuple_ptr.release();
+}
+
+__attribute__((visibility("default")))
+void MosaicGpuUnload(void** tuple_ptr) {
+  delete reinterpret_cast<CompiledKernel*>(tuple_ptr[2]);
+  delete[] tuple_ptr;
+}
+
+}  // extern "C"
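For reference, a sketch (not part of the diff) of how a client binds the two entry points exported above. It mirrors the ctypes setup in as_torch_gpu_kernel and assumes the symbols live in the CUDA PJRT plugin library, as in the Python half of this change. MosaicGpuCompile returns a three-slot void** tuple: slot 0 is the kernel context, slot 1 the host launch function, and slot 2 an owning CompiledKernel* that MosaicGpuUnload deletes.

import ctypes
import jax_plugins.xla_cuda12 as cuda_plugin  # plugin module used by the Python diff above

def load_mosaic_entry_points():
  # Open the PJRT GPU plugin and bind the two symbols kept visible by the version script.
  dll = ctypes.CDLL(cuda_plugin._get_library_path())
  compile_func = dll.MosaicGpuCompile
  compile_func.argtypes = [ctypes.c_void_p]               # serialized MLIR module bytes
  compile_func.restype = ctypes.POINTER(ctypes.c_void_p)  # [ctx, launch_fn, CompiledKernel*]
  unload_func = dll.MosaicGpuUnload
  unload_func.argtypes = [compile_func.restype]           # takes back the same 3-slot tuple
  unload_func.restype = None
  return compile_func, unload_func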
@@ -64,11 +64,12 @@ py_test(
 cc_binary(
     name = "pjrt_c_api_gpu_plugin.so",
     linkopts = [
-        "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)",
+        "-Wl,--version-script,$(location :gpu_version_script.lds)",
         "-Wl,--no-undefined",
     ],
     linkshared = True,
     deps = [
+        ":gpu_version_script.lds",
         "@xla//xla/pjrt/c:pjrt_c_api_gpu",
         "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
         "@xla//xla/service:gpu_plugin",
jaxlib/tools/gpu_version_script.lds (new file, 11 lines)
@@ -0,0 +1,11 @@
+VERS_1.0 {
+  global:
+    extern "C" {
+      GetPjrtApi;
+      MosaicGpuCompile;
+      MosaicGpuUnload;
+    };
+
+  local:
+    *;
+};
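A quick sanity check (a suggestion, not part of the change) that the new version script actually leaves the Mosaic symbols visible in the built plugin: resolving them through ctypes raises AttributeError if the linker hid them.

import ctypes
import jax_plugins.xla_cuda12 as cuda_plugin  # assumed plugin module, as in the diff above

dll = ctypes.CDLL(cuda_plugin._get_library_path())
for sym in ("GetPjrtApi", "MosaicGpuCompile", "MosaicGpuUnload"):
  getattr(dll, sym)  # raises AttributeError if the symbol is not exported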
@@ -19,6 +19,7 @@ from functools import partial
 import itertools
 import math
 import operator
+import unittest
 
 from absl.testing import absltest, parameterized
 import jax
@@ -1389,5 +1390,28 @@ class ProfilerTest(TestCase):
     jax.block_until_ready(f(xd))
 
 
+class TorchTest(TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    try:
+      import torch
+    except ImportError:
+      raise unittest.SkipTest("Test requires PyTorch")
+    cls.torch = torch
+
+  def test_basic(self):
+    def kernel(ctx, i_gmem, o_gmem, _):
+      x = mgpu.FragmentedArray.load_strided(i_gmem)
+      (x + x).store_untiled(o_gmem)
+
+    ty = jax.ShapeDtypeStruct((128, 128), jnp.float32)
+    x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda')
+    f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ())
+    y = f(x)
+    np.testing.assert_allclose(y.cpu(), x.cpu() * 2)
+    del y  # Make sure the destructor runs successfully.
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())