Sergei Lebedev a14e6968bf [mosaic] Migrated the serialization pass from codegen to pass_boilerplate.h
This prepares the generalization of the serialization pass to handle both
Mosaic TPU and GPU.

PiperOrigin-RevId: 705628923
2024-12-12 14:19:36 -08:00

218 lines
6.2 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "pybind_extension")
# Package-wide defaults: no license targets are applied by default, and
# visibility of every target is restricted to `//jax:mosaic_gpu_users`
# unless a target overrides it.
package(
    default_applicable_licenses = [],
    default_visibility = ["//jax:mosaic_gpu_users"],
)
# Python-facing target: depends on the `_mosaic_gpu_ext` extension module and
# carries `libmosaic_gpu_runtime.so` as a runtime data file.
py_library(
    name = "mosaic_gpu",
    data = [":libmosaic_gpu_runtime.so"],
    deps = [":_mosaic_gpu_ext"],
)
# C++ library built from target.cc / target.h. Depends on Abseil status and
# string utilities and on LLVM's MC library.
cc_library(
    name = "target",
    srcs = ["target.cc"],
    hdrs = ["target.h"],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@llvm-project//llvm:MC",
    ],
)
# MLIR passes for Mosaic GPU: the launch-lowering sources plus passes.cc,
# compiled against the shared `//jaxlib:pass_boilerplate` helpers and the
# MLIR/LLVM dialect and conversion libraries listed below.
cc_library(
    name = "passes",
    srcs = [
        "launch_lowering.cc",
        "passes.cc",
    ],
    hdrs = [
        "launch_lowering.h",
        "passes.h",
    ],
    deps = [
        "//jaxlib:pass_boilerplate",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:DataLayoutInterfaces",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:GPUToGPURuntimeTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMCommonConversion",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
)
# Source and header lists for the C API wrapper around the passes. They are
# kept as constants because several `mlir_capi*` targets in this file reuse
# the same files.
CAPI_SOURCES = [
    "integrations/c/passes.cc",
]

CAPI_HEADERS = [
    "integrations/c/passes.h",
]
# C API library: compiles the C API sources on top of `:passes`, using the
# header-only MLIR C API target (`CAPIIRHeaders`).
cc_library(
    name = "mlir_capi",
    srcs = CAPI_SOURCES,
    hdrs = CAPI_HEADERS,
    deps = [
        ":passes",
        "@llvm-project//mlir:CAPIIRHeaders",
    ],
)
# Headers only (no sources): for clients that call the C API but get its
# implementation from a separate shared library.
cc_library(
    name = "mlir_capi_headers",
    hdrs = CAPI_HEADERS,
    deps = [
        "@llvm-project//mlir:CAPIIRHeaders",
    ],
)
# Variant of the C API library for shared libraries that export the C API.
# `alwayslink` keeps every object file in the final link even when nothing
# in the binary references its symbols.
cc_library(
    name = "mlir_capi_objects",
    srcs = CAPI_SOURCES,
    hdrs = CAPI_HEADERS,
    deps = [
        ":passes",
        "@llvm-project//mlir:CAPIIRObjects",
    ],
    alwayslink = True,
)
# Builds runtime.cc as an ordinary library against the CUDA headers. The same
# source file is also compiled into `libmosaic_gpu_runtime.so` in this package.
cc_library(
    name = "runtime",
    srcs = ["runtime.cc"],
    deps = [
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
# Custom-call implementation (custom_call.cc). Depends on the local `:passes`
# and `:target` libraries, the CUDA vendor shim, Abseil, a broad set of MLIR
# dialects / conversions / LLVM IR translations, and XLA's custom-call status
# and target-registry libraries.
cc_library(
    name = "custom_call",
    srcs = ["custom_call.cc"],
    deps = [
        ":passes",
        ":target",
        "//jaxlib/cuda:cuda_vendor",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/synchronization",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ArithToLLVM",
        "@llvm-project//mlir:ArithTransforms",
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:ComplexToLLVM",
        "@llvm-project//mlir:ControlFlowToLLVM",
        "@llvm-project//mlir:ConversionPasses",
        "@llvm-project//mlir:ExecutionEngine",
        "@llvm-project//mlir:ExecutionEngineUtils",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FuncToLLVM",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:GPUToLLVMIRTranslation",
        "@llvm-project//mlir:GPUTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:IndexToLLVM",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:MathToLLVM",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:MemRefToLLVM",
        "@llvm-project//mlir:MemRefTransforms",
        "@llvm-project//mlir:NVGPUDialect",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:NVVMTarget",
        "@llvm-project//mlir:NVVMToLLVM",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:UBToLLVM",
        "@llvm-project//mlir:VectorDialect",
        "@xla//xla/service:custom_call_status",
        "@xla//xla/service:custom_call_target_registry",
    ],
    # Keep all object files in the final link even if unreferenced.
    # NOTE(review): presumably needed because custom_call.cc registers its
    # custom-call target through a static initializer -- confirm.
    alwayslink = True,
)
# Python extension module built from mosaic_gpu_ext.cc with exceptions enabled
# and strict-aliasing optimizations disabled.
pybind_extension(
    name = "_mosaic_gpu_ext",
    srcs = ["mosaic_gpu_ext.cc"],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    # When the `use_jax_cuda_pip_rpaths` config setting matches, add an rpath
    # that resolves, relative to this shared object, to
    # `nvidia/cuda_runtime/lib`. `$$ORIGIN` is Bazel's escaping for a literal
    # `$ORIGIN` passed to the linker.
    linkopts = select({
        "@xla//xla/python:use_jax_cuda_pip_rpaths": [
            "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
        ],
        "//conditions:default": [],
    }),
    deps = [
        "//jaxlib:kernel_nanobind_helpers",
        "//jaxlib/cuda:cuda_vendor",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/strings",
        "@nanobind",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
        "@xla//xla/tsl/cuda:cudart",
    ],
)
# Shared-object build of runtime.cc; `:mosaic_gpu` ships it as a data file.
cc_binary(
    name = "libmosaic_gpu_runtime.so",
    srcs = ["runtime.cc"],
    # Default symbol visibility so the symbols defined in runtime.cc are
    # exported from the shared object.
    copts = ["-fvisibility=default"],
    # When the `use_jax_cuda_pip_rpaths` config setting matches, add an rpath
    # that resolves, relative to this shared object, to
    # `nvidia/cuda_runtime/lib`. `$$ORIGIN` is Bazel's escaping for a literal
    # `$ORIGIN` passed to the linker.
    linkopts = select({
        "@xla//xla/python:use_jax_cuda_pip_rpaths": [
            "-Wl,-rpath,$$ORIGIN/../../../nvidia/cuda_runtime/lib",
        ],
        "//conditions:default": [],
    }),
    # Produce a shared library rather than an executable. Written as a boolean
    # to match the other boolean attributes in this file (`alwayslink = True`).
    linkshared = True,
    tags = [
        # Excluded from wildcard target patterns; built only when requested
        # explicitly or as a dependency.
        "manual",
        "notap",
    ],
    deps = [
        "@local_config_cuda//cuda:cuda_headers",
        "@xla//xla/tsl/cuda:cudart",
    ],
)