mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Register NVPTX LLVM backend from Mosaic custom call
Until now, Mosaic implicitly relied on XLA to register the NVPTX target, which caused problems in cases where only a Mosaic kernel is compiled and XLA has not initialized the LLVM NVPTX target. PiperOrigin-RevId: 746433654
This commit is contained in:
parent
a1c06fcb3b
commit
896557f07b
jaxlib/mosaic/gpu
@ -133,11 +133,12 @@ cc_library(
|
||||
name = "custom_call",
|
||||
srcs = ["custom_call.cc"],
|
||||
deps = [
|
||||
":mosaic_gpu_comm",
|
||||
":passes",
|
||||
":target",
|
||||
"//jaxlib/cuda:cuda_vendor",
|
||||
"//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
|
||||
"//jaxlib/mosaic/gpu:mosaic_gpu_comm",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/cleanup",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/call_once.h"
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/cleanup/cleanup.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
@ -105,6 +106,16 @@ namespace ffi = xla::ffi;
|
||||
using MosaicInitFunc = void(void****);
|
||||
using MosaicHostFunc = void(void**);
|
||||
|
||||
// Registers the LLVM NVPTX backend components (target, target info, target MC
// and asm printer). The four LLVMInitializeNVPTX* calls run inside
// absl::call_once guarded by a function-local static once_flag, so repeated
// calls to this function perform the registration only the first time.
//
// Per the commit description, Mosaic previously relied on XLA to have
// registered the NVPTX target; this helper lets the Mosaic custom call do the
// registration itself.
void EnsureLLVMNVPTXTargetIsRegistered() {
|
||||
static absl::once_flag register_nvptx_target_flag;
|
||||
absl::call_once(register_nvptx_target_flag, []() {
|
||||
LLVMInitializeNVPTXTarget();
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
LLVMInitializeNVPTXTargetMC();
|
||||
LLVMInitializeNVPTXAsmPrinter();
|
||||
});
|
||||
}
|
||||
|
||||
absl::StatusOr<std::pair<std::string, std::string>> GetSmAndPtxIsaVersion() {
|
||||
// Assumes driver has been initialized and a context exists. XLA already has
|
||||
// some utilities to query this, but we try to stay runtime-agnostic, so we
|
||||
@ -123,13 +134,18 @@ absl::StatusOr<std::pair<std::string, std::string>> GetSmAndPtxIsaVersion() {
|
||||
device) != CUDA_SUCCESS) {
|
||||
return absl::InternalError("Failed to get minor compute capability");
|
||||
}
|
||||
EnsureLLVMNVPTXTargetIsRegistered();
|
||||
return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor);
|
||||
}
|
||||
|
||||
|
||||
mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
|
||||
mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target,
|
||||
const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) {
|
||||
static bool register_once = []() {
|
||||
static absl::once_flag register_passes_flag;
|
||||
absl::call_once(register_passes_flag, []() {
|
||||
EnsureLLVMNVPTXTargetIsRegistered();
|
||||
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
@ -157,8 +173,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
|
||||
mosaic::gpu::registerByvalInsertionPass();
|
||||
mlir::arith::registerArithExpandOpsPass();
|
||||
return true;
|
||||
}();
|
||||
(void)register_once;
|
||||
});
|
||||
return mlir::parsePassPipeline(absl::StrCat(
|
||||
R"(
|
||||
builtin.module(
|
||||
|
Loading…
x
Reference in New Issue
Block a user