From 611ad630603cffa88aa714bf876340af315dd819 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 6 Sep 2024 16:09:58 +0000 Subject: [PATCH] Add basic PyTorch integration for Mosaic GPU We have already had most of the relevant pieces and we only needed to connect them together. The most sensitive change is perhaps that I needed to expose one more symbol from the XLA GPU plugin, but I don't think it should be a problem. --- jax/experimental/mosaic/gpu/__init__.py | 121 +++++++++++++++++++++--- jaxlib/mosaic/gpu/custom_call.cc | 94 ++++++++++++------ jaxlib/tools/BUILD.bazel | 3 +- jaxlib/tools/gpu_version_script.lds | 11 +++ tests/mosaic/gpu_test.py | 24 +++++ 5 files changed, 212 insertions(+), 41 deletions(-) create mode 100644 jaxlib/tools/gpu_version_script.lds diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 2e2941fca..0e263844b 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -27,6 +27,7 @@ import subprocess import tempfile import time from typing import Any, Generic, TypeVar +import weakref import jax from jax._src import config @@ -800,6 +801,21 @@ def _lower_as_gpu_kernel( return module, out_shape, unwrap_output_tuple +def _declare_runtime_functions(): + """Declares the runtime functions that can be used by the generated code.""" + ptr_ty = ir.Type.parse("!llvm.ptr") + i64 = ir.IntegerType.get_signless(64) + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + init_tma_desc_type = ir.FunctionType.get(arg_tys, []) + func.FuncOp( + "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" + ) + memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) + func.FuncOp( + "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" + ) + + def as_gpu_kernel( body, grid: tuple[int, int, int], @@ -867,16 +883,97 @@ def as_gpu_kernel( return kernel -def _declare_runtime_functions(): - """Declares the runtime functions that can be used by the generated code.""" - ptr_ty = ir.Type.parse("!llvm.ptr") - i64 = ir.IntegerType.get_signless(64) - arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] - init_tma_desc_type = ir.FunctionType.get(arg_tys, []) - func.FuncOp( - "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" - ) - memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) - func.FuncOp( - "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" +def as_torch_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + try: + import torch + except ImportError: + raise RuntimeError("as_torch_gpu_kernel requires PyTorch") + torch.cuda.init() # Make sure CUDA context is set up. + + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + flat_out_types, out_treedef = jax.tree.flatten(out_shape) + expected_arg_treedef = jax.tree.structure(in_shape) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) ) + + # Get our hands on the compilation and unload functions + try: + import jax_plugins.xla_cuda12 as cuda_plugin + except ImportError: + raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " + "that use backend plugins") + dll = ctypes.CDLL(cuda_plugin._get_library_path()) + compile_func = dll.MosaicGpuCompile + compile_func.argtypes = [ctypes.c_void_p] + compile_func.restype = ctypes.POINTER(ctypes.c_void_p) + unload_func = dll.MosaicGpuUnload + unload_func.argtypes = [compile_func.restype] + unload_func.restype = None + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + compiled = compile_func(ctypes.c_char_p(module_asm)) + if compiled is None: + raise RuntimeError("Failed to compile the module") + ctx, launch_ptr = compiled[0], compiled[1] + ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) + launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + + def as_torch_dtype(dtype): + # torch contains NumPy-compatible dtypes in its top namespace + return getattr(torch, np.dtype(dtype).name) + + def apply(*args): + flat_args, arg_treedef = jax.tree.flatten(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + # Construct a device pointer list like in the XLA calling convention + buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() + i = -1 # Define i in case there are no args + device = 'cuda' + for i, arg in enumerate(flat_args): + buffers[i] = arg.data_ptr() + device = arg.device + flat_outs = [] + for i, t in enumerate(flat_out_types, i + 1): + out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) + flat_outs.append(out) + buffers[i] = out.data_ptr() + # Allocate another buffer for args of the host-side program. This is sadly + # the default MLIR calling convention. + args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() + args_ptr[0] = ctx_ptr_ptr + args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) + args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), + ctypes.POINTER(ctypes.c_void_p)) + launch(args_ptr) + return jax.tree.unflatten(out_treedef, flat_outs) + + # Unload the compiled code when the Python function is destroyed. + def unload(_): + unload_func(compiled) + apply.destructor = weakref.ref(apply, unload) + + return apply diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 2e5723b18..103f9f78c 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -377,10 +377,40 @@ GetKernelCache() { return std::make_pair(&context_cache, &mutex); } + +absl::StatusOr CompileAndInit(const char* module) { + mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + InitContext(&context); + mlir::ParserConfig parse_config(&context); + auto module_op = + mlir::parseSourceString(module, parse_config); + if (!module_op) { + return absl::InternalError("Failed to parse module"); + } + auto maybe_engine = Compile(*module_op); + if (!maybe_engine.ok()) { + return maybe_engine.status(); + } + mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + auto main = execution_engine->lookupPacked("_mlir_ciface_main"); + auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); + if (!init || !main) { + return absl::InternalError("Failed to retrieve kernel function"); + } + void* module_ptr = nullptr; + void* kernel_ptr = nullptr; + void** module_ptr_ptr = &module_ptr; + void** kernel_ptr_ptr = &kernel_ptr; + void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; + reinterpret_cast(*init)(init_args); + return CompiledKernel(std::move(*maybe_engine), kernel_ptr, + reinterpret_cast(*main)); +} + // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CompileAndInit( +absl::StatusOr> CachedCompileAndInit( CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; @@ -397,33 +427,11 @@ absl::StatusOr> CompileAndInit( absl::MutexLock lock(mutex); // We released the reader lock, another thread might have initialized it. if (cache->find(key) == cache->end()) { - mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); - InitContext(&context); - mlir::ParserConfig parse_config(&context); - auto module_op = - mlir::parseSourceString(module, parse_config); - if (!module_op) { - return absl::InternalError("Failed to parse module"); + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return compiled.status(); } - auto maybe_engine = Compile(*module_op); - if (!maybe_engine.ok()) { - return maybe_engine.status(); - } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); - auto main = execution_engine->lookupPacked("_mlir_ciface_main"); - auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); - if (!init || !main) { - return absl::InternalError("Failed to retrieve kernel function"); - } - void* module_ptr = nullptr; - void* kernel_ptr = nullptr; - void** module_ptr_ptr = &module_ptr; - void** kernel_ptr_ptr = &kernel_ptr; - void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; - reinterpret_cast(*init)(init_args); - cache->insert_or_assign( - key, CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*main))); + cache->insert_or_assign(key, std::move(*compiled)); } return cache->at(key).GetHostLaunch(); } @@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, abort(); } CacheKey key(hash, reinterpret_cast(ctx)); - auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash)); + auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); if (!ctx_and_kernel.ok()) { XlaCustomCallStatusSetFailure(status, ctx_and_kernel.status().message().data(), @@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); } // namespace + +extern "C" { + +__attribute__((visibility("default"))) +void** MosaicGpuCompile(const char* module) { + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return nullptr; + } + auto [ctx, launch] = compiled->GetHostLaunch(); + auto tuple_ptr = std::unique_ptr(new void*[3]); + if (!tuple_ptr) { + return nullptr; + } + tuple_ptr.get()[0] = ctx; + tuple_ptr.get()[1] = reinterpret_cast(launch); + tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled)); + if (!tuple_ptr.get()[2]) { + return nullptr; + } + return tuple_ptr.release(); +} + +__attribute__((visibility("default"))) +void MosaicGpuUnload(void** tuple_ptr) { + delete reinterpret_cast(tuple_ptr[2]); + delete[] tuple_ptr; +} + +} // extern "C" diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 8463cba08..4642af120 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -64,11 +64,12 @@ py_test( cc_binary( name = "pjrt_c_api_gpu_plugin.so", linkopts = [ - "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)", + "-Wl,--version-script,$(location :gpu_version_script.lds)", "-Wl,--no-undefined", ], linkshared = True, deps = [ + ":gpu_version_script.lds", "@xla//xla/pjrt/c:pjrt_c_api_gpu", "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", "@xla//xla/service:gpu_plugin", diff --git a/jaxlib/tools/gpu_version_script.lds b/jaxlib/tools/gpu_version_script.lds new file mode 100644 index 000000000..8e46b2c59 --- /dev/null +++ b/jaxlib/tools/gpu_version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + MosaicGpuCompile; + MosaicGpuUnload; + }; + + local: + *; +}; diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ec9a7cd8b..27dc1c984 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -19,6 +19,7 @@ from functools import partial import itertools import math import operator +import unittest from absl.testing import absltest, parameterized import jax @@ -1387,5 +1388,28 @@ class ProfilerTest(TestCase): jax.block_until_ready(f(xd)) +class TorchTest(TestCase): + + @classmethod + def setUpClass(cls): + try: + import torch + except ImportError: + raise unittest.SkipTest("Test requires PyTorch") + cls.torch = torch + + def test_basic(self): + def kernel(ctx, i_gmem, o_gmem, _): + x = mgpu.FragmentedArray.load_strided(i_gmem) + (x + x).store_untiled(o_gmem) + + ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) + x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda') + f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) + y = f(x) + np.testing.assert_allclose(y.cpu(), x.cpu() * 2) + del y # Make sure the destructor runs successfully. + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())