1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

Register NVPTX LLVM backend from Mosaic custom call

So far Mosaic was implicitly relying on XLA to register the NVPTX target which made problems in cases where only a Mosaic kernel gets compiled and XLA didn't initialize the LLVM NVPTX target.

PiperOrigin-RevId: 746433654
This commit is contained in:
Henning Becker 2025-04-11 06:14:36 -07:00 committed by jax authors
parent a1c06fcb3b
commit 896557f07b
2 changed files with 20 additions and 4 deletions
jaxlib/mosaic/gpu

@ -133,11 +133,12 @@ cc_library(
name = "custom_call",
srcs = ["custom_call.cc"],
deps = [
":mosaic_gpu_comm",
":passes",
":target",
"//jaxlib/cuda:cuda_vendor",
"//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
"//jaxlib/mosaic/gpu:mosaic_gpu_comm",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",

@ -32,6 +32,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "absl/base/optimization.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
@ -105,6 +106,16 @@ namespace ffi = xla::ffi;
using MosaicInitFunc = void(void****);
using MosaicHostFunc = void(void**);
void EnsureLLVMNVPTXTargetIsRegistered() {
static absl::once_flag register_nvptx_target_flag;
absl::call_once(register_nvptx_target_flag, []() {
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
});
}
absl::StatusOr<std::pair<std::string, std::string>> GetSmAndPtxIsaVersion() {
// Assumes driver has been initialized and a context exists. XLA already has
// some utilities to query this, but we try to stay runtime-agnostic, so we
@ -123,13 +134,18 @@ absl::StatusOr<std::pair<std::string, std::string>> GetSmAndPtxIsaVersion() {
device) != CUDA_SUCCESS) {
return absl::InternalError("Failed to get minor compute capability");
}
EnsureLLVMNVPTXTargetIsRegistered();
return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor);
}
mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target,
const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) {
static bool register_once = []() {
static absl::once_flag register_passes_flag;
absl::call_once(register_passes_flag, []() {
EnsureLLVMNVPTXTargetIsRegistered();
llvm::InitializeNativeTarget();
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
@ -157,8 +173,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
mosaic::gpu::registerByvalInsertionPass();
mlir::arith::registerArithExpandOpsPass();
return true;
}();
(void)register_once;
});
return mlir::parsePassPipeline(absl::StrCat(
R"(
builtin.module(