mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic GPU] Stop using the MLIR CUDA runtime
This ports the remaining few functions we depended on to the Mosaic GPU runtime. This has the additional benefit of avoiding the expensive driver calls to determine maximum SMEM bounds that the MLIR runtime does at every kernel launch. PiperOrigin-RevId: 629069842
This commit is contained in:
parent
d92d9394ae
commit
32cb7c3f94
@ -111,7 +111,7 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes)
|
||||
del out_types # Unused.
|
||||
runtime_path = (
|
||||
pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent.parent.parent
|
||||
/ "mosaic" / "gpu" / "libmlir_cuda_runtime.so"
|
||||
/ "mosaic" / "gpu" / "libmosaic_gpu_runtime.so"
|
||||
)
|
||||
shared_libs = [str(runtime_path)] if runtime_path.exists() else []
|
||||
engine = ExecutionEngine(
|
||||
|
@ -21,7 +21,7 @@ package(
|
||||
|
||||
py_library(
|
||||
name = "mosaic_gpu",
|
||||
data = [":libmlir_cuda_runtime.so"],
|
||||
data = [":libmosaic_gpu_runtime.so"],
|
||||
deps = [
|
||||
"//jaxlib/mlir:execution_engine",
|
||||
"//jaxlib/mlir:gpu_dialect",
|
||||
@ -96,11 +96,8 @@ cc_library(
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "libmlir_cuda_runtime.so",
|
||||
srcs = [
|
||||
"runtime.cc",
|
||||
"@llvm-project//mlir:lib/ExecutionEngine/CudaRuntimeWrappers.cpp",
|
||||
],
|
||||
name = "libmosaic_gpu_runtime.so",
|
||||
srcs = ["runtime.cc"],
|
||||
copts = ["-fvisibility=default"],
|
||||
linkopts = select({
|
||||
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
|
||||
@ -114,7 +111,6 @@ cc_binary(
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//mlir:mlir_c_runner_utils_hdrs",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
|
@ -100,21 +100,22 @@ void emitRuntimeDecls(mlir::ModuleOp module) {
|
||||
auto i64 = mlir::IntegerType::get(module.getContext(), 64);
|
||||
auto decl_builder = mlir::OpBuilder::atBlockBegin(module.getBody());
|
||||
decl_builder.create<mlir::func::FuncOp>(
|
||||
module.getLoc(), decl_builder.getStringAttr("mgpuLaunchKernel"),
|
||||
module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_launch_kernel"),
|
||||
mlir::FunctionType::get(module.getContext(),
|
||||
{ptr_ty, i64, i64, i64, i64, i64, i64, i32,
|
||||
ptr_ty, ptr_ty, ptr_ty, i64},
|
||||
ptr_ty, ptr_ty},
|
||||
{}),
|
||||
decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr,
|
||||
/*res_attrs=*/nullptr);
|
||||
decl_builder.create<mlir::func::FuncOp>(
|
||||
module.getLoc(), decl_builder.getStringAttr("mgpuModuleLoad"),
|
||||
mlir::FunctionType::get(module.getContext(), {ptr_ty, i64}, {ptr_ty}),
|
||||
module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_module_load"),
|
||||
mlir::FunctionType::get(module.getContext(), {ptr_ty}, {ptr_ty}),
|
||||
decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr,
|
||||
/*res_attrs=*/nullptr);
|
||||
decl_builder.create<mlir::func::FuncOp>(
|
||||
module.getLoc(), decl_builder.getStringAttr("mgpuModuleGetFunction"),
|
||||
mlir::FunctionType::get(module.getContext(), {ptr_ty, ptr_ty}, {ptr_ty}),
|
||||
module.getLoc(), decl_builder.getStringAttr("mosaic_gpu_get_function"),
|
||||
mlir::FunctionType::get(module.getContext(), {ptr_ty, ptr_ty, i32},
|
||||
{ptr_ty}),
|
||||
decl_builder.getStringAttr("private"), /*arg_attr=*/nullptr,
|
||||
/*res_attrs=*/nullptr);
|
||||
}
|
||||
@ -122,8 +123,9 @@ void emitRuntimeDecls(mlir::ModuleOp module) {
|
||||
void buildInitFunction(mlir::OpBuilder &module_builder,
|
||||
mlir::func::FuncOp init_func,
|
||||
llvm::StringRef kernel_name,
|
||||
mlir::gpu::ObjectAttr object) {
|
||||
auto i64 = mlir::IntegerType::get(init_func.getContext(), 64);
|
||||
mlir::gpu::ObjectAttr object,
|
||||
mlir::Value dynamic_smem_size) {
|
||||
auto i32 = mlir::IntegerType::get(init_func.getContext(), 32);
|
||||
auto ptr_ty = mlir::LLVM::LLVMPointerType::get(init_func.getContext());
|
||||
mlir::Location loc = init_func.getLoc();
|
||||
auto builder =
|
||||
@ -139,13 +141,10 @@ void buildInitFunction(mlir::OpBuilder &module_builder,
|
||||
/*value=*/object.getObject());
|
||||
mlir::Value binary_addr = builder.create<mlir::LLVM::AddressOfOp>(
|
||||
init_func.getLoc(), binary_global_decl);
|
||||
mlir::Value binary_size = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, i64, builder.getI64IntegerAttr(object.getObject().size()));
|
||||
mlir::Value module_handle =
|
||||
builder
|
||||
.create<mlir::func::CallOp>(
|
||||
loc, "mgpuModuleLoad", ptr_ty,
|
||||
mlir::ValueRange{binary_addr, binary_size})
|
||||
.create<mlir::func::CallOp>(loc, "mosaic_gpu_module_load", ptr_ty,
|
||||
binary_addr)
|
||||
.getResult(0);
|
||||
|
||||
// TODO(apaszke): This will create duplicate globals if the kernel
|
||||
@ -164,11 +163,22 @@ void buildInitFunction(mlir::OpBuilder &module_builder,
|
||||
llvm::Twine(kernel_name).concat(llvm::Twine('\0'))));
|
||||
mlir::Value kernel_name_ptr =
|
||||
builder.create<mlir::LLVM::AddressOfOp>(loc, kernel_name_global);
|
||||
mlir::Value used_smem = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, i32, builder.getI32IntegerAttr(0));
|
||||
if (dynamic_smem_size) {
|
||||
if (auto const_smem =
|
||||
dynamic_smem_size.getDefiningOp<mlir::LLVM::ConstantOp>()) {
|
||||
used_smem = builder.create<mlir::LLVM::ConstantOp>(
|
||||
loc, i32,
|
||||
builder.getI32IntegerAttr(
|
||||
mlir::cast<mlir::IntegerAttr>(const_smem.getValue()).getSInt()));
|
||||
}
|
||||
}
|
||||
mlir::Value kernel_handle =
|
||||
builder
|
||||
.create<mlir::func::CallOp>(
|
||||
loc, "mgpuModuleGetFunction", ptr_ty,
|
||||
mlir::ValueRange{module_handle, kernel_name_ptr})
|
||||
loc, "mosaic_gpu_get_function", ptr_ty,
|
||||
mlir::ValueRange{module_handle, kernel_name_ptr, used_smem})
|
||||
.getResult(0);
|
||||
builder.create<mlir::func::ReturnOp>(loc, kernel_handle);
|
||||
}
|
||||
@ -176,13 +186,12 @@ void buildInitFunction(mlir::OpBuilder &module_builder,
|
||||
mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func,
|
||||
mlir::gpu::LaunchFuncOp launch,
|
||||
mlir::Value kernel_handle) {
|
||||
auto ptr_ty = mlir::LLVM::LLVMPointerType::get(func.getContext());
|
||||
// Lower gpu.launch_func to a call to mgpuLaunchKernel.
|
||||
mlir::OpBuilder builder(launch);
|
||||
mlir::Value dynamic_smem = launch.getDynamicSharedMemorySize();
|
||||
if (!dynamic_smem) {
|
||||
dynamic_smem = builder.create<mlir::LLVM::ConstantOp>(
|
||||
launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0));
|
||||
launch.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(0));
|
||||
}
|
||||
mlir::Value arg_ptr_array = packKernelArgs(builder, launch);
|
||||
if (launch.hasClusterSize()) {
|
||||
@ -190,17 +199,11 @@ mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func,
|
||||
}
|
||||
mlir::gpu::KernelDim3 grid = launch.getGridSizeOperandValues();
|
||||
mlir::gpu::KernelDim3 block = launch.getBlockSizeOperandValues();
|
||||
mlir::Value llvm_nullptr =
|
||||
builder.create<mlir::LLVM::ZeroOp>(launch.getLoc(), ptr_ty);
|
||||
mlir::Value stream = launch.getAsyncObject();
|
||||
mlir::Value param_count = builder.create<mlir::LLVM::ConstantOp>(
|
||||
launch.getLoc(), builder.getI64Type(),
|
||||
builder.getI64IntegerAttr(launch.getNumKernelOperands()));
|
||||
builder.create<mlir::func::CallOp>(
|
||||
launch.getLoc(), "mgpuLaunchKernel", mlir::TypeRange{},
|
||||
launch.getLoc(), "mosaic_gpu_launch_kernel", mlir::TypeRange{},
|
||||
mlir::ValueRange{kernel_handle, grid.x, grid.y, grid.z, block.x, block.y,
|
||||
block.z, dynamic_smem, stream, arg_ptr_array,
|
||||
llvm_nullptr, param_count});
|
||||
block.z, dynamic_smem, stream, arg_ptr_array});
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@ -282,7 +285,8 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass<mlir::ModuleOp> {
|
||||
}
|
||||
|
||||
buildInitFunction(module_builder, init_func,
|
||||
launch.getKernelName().getValue(), object);
|
||||
launch.getKernelName().getValue(), object,
|
||||
launch.getDynamicSharedMemorySize());
|
||||
|
||||
// Add a new function argument for the kernel handle.
|
||||
func.insertArgument(0, ptr_ty,
|
||||
|
@ -92,4 +92,54 @@ void mosaic_gpu_memcpy_async_h2d(CUdeviceptr dst, void *src, uint64_t bytes,
|
||||
}
|
||||
}
|
||||
|
||||
void* mosaic_gpu_module_load(void *data) {
|
||||
CUmodule module = nullptr;
|
||||
if (auto result = cuModuleLoadData(&module, data); result != CUDA_SUCCESS) {
|
||||
const char *ptr = nullptr;
|
||||
cuGetErrorString(result, &ptr);
|
||||
fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr);
|
||||
abort();
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
void *mosaic_gpu_get_function(CUmodule module, const char *name,
|
||||
int32_t smem_bytes) {
|
||||
CUfunction function = nullptr;
|
||||
CUresult result = cuModuleGetFunction(&function, module, name);
|
||||
if (result != CUDA_SUCCESS) {
|
||||
const char *ptr = nullptr;
|
||||
cuGetErrorString(result, &ptr);
|
||||
fprintf(stderr, "cuModuleGetFunction failed: %s\n", ptr);
|
||||
abort();
|
||||
}
|
||||
if (smem_bytes) {
|
||||
result = cuFuncSetAttribute(
|
||||
function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes);
|
||||
if (result != CUDA_SUCCESS) {
|
||||
const char *ptr = nullptr;
|
||||
cuGetErrorString(result, &ptr);
|
||||
fprintf(stderr, "cuFuncSetAttribute failed: %s\n", ptr);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
return function;
|
||||
}
|
||||
|
||||
void mosaic_gpu_launch_kernel(CUfunction function, int64_t grid_x,
|
||||
int64_t grid_y, int64_t grid_z, int64_t block_x,
|
||||
int64_t block_y, int64_t block_z,
|
||||
int32_t smem_bytes, CUstream stream,
|
||||
void **params) {
|
||||
CUresult result =
|
||||
cuLaunchKernel(function, grid_x, grid_y, grid_z, block_x, block_y,
|
||||
block_z, smem_bytes, stream, params, nullptr);
|
||||
if (result != CUDA_SUCCESS) {
|
||||
const char *ptr = nullptr;
|
||||
cuGetErrorString(result, &ptr);
|
||||
fprintf(stderr, "cuLaunchKernel failed: %s\n", ptr);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user