Peter Hawkins 64eae324ee Migrate JAX MLIR Python dialect extensions to nanobind.
Now that https://github.com/llvm/llvm-project/pull/117922 has landed upstream, we can work towards removing our last uses of pybind11.

PiperOrigin-RevId: 705872751
2024-12-13 07:08:28 -08:00

436 lines
12 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"if_windows",
"py_extension",
"pybind_extension",
"windows_cc_shared_mlir_library",
)
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
# Package-wide defaults: every target in this package is publicly visible and
# is declared under the "notice" license type.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)
# Compiler flags shared by the extension targets in this file: turn on C++
# exceptions and RTTI.
COPTS = [
    "-fexceptions",
    "-frtti",
]
# Linker flags shared by the extension targets in this file, selected per
# platform.
LINKOPTS = select({
    # macOS: add the loading binary's directory to the runtime search path, and
    # rename the __TEXT,text_env section to __TEXT,__text.
    "@xla//xla/tsl:macos": [
        "-Wl,-rpath,@loader_path/",
        "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text",
    ],
    # Windows: no additional linker flags.
    "@xla//xla/tsl:windows": [],
    # Everything else: add the binary's own directory to the runtime search
    # path. `$$` is the Bazel escape for a literal `$`.
    "//conditions:default": [
        "-Wl,-rpath,$$ORIGIN/",
    ],
})
# `_mlir` extension module, compiled from MainModule.cpp in the upstream MLIR
# Python bindings.
py_extension(
    name = "_mlir",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        # MLIR C API, supplied by the shared library defined in this file.
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
    ],
)
# GPU dialect extension module, compiled from the upstream DialectGPU.cpp
# binding source. This target still depends on pybind11.
py_extension(
    name = "_mlirDialectsGPU",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
# GPU passes extension module, compiled from the upstream GPUPasses.cpp binding
# source. This target still depends on pybind11.
py_extension(
    name = "_mlirGPUPasses",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@pybind11",
    ],
)
# NVGPU dialect extension module, compiled from the upstream DialectNVGPU.cpp
# binding source. This target still depends on pybind11.
py_extension(
    name = "_mlirDialectsNVGPU",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPINVGPUHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
# LLVM dialect extension module, compiled from the upstream DialectLLVM.cpp
# binding source. This target still depends on pybind11.
py_extension(
    name = "_mlirDialectsLLVM",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPILLVMHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
# SparseTensor dialect extension module, compiled from the upstream
# DialectSparseTensor.cpp binding source. This target still depends on
# pybind11.
py_extension(
    name = "_mlirDialectsSparseTensor",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPISparseTensorHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
# SparseTensor passes extension module, compiled from the upstream
# SparseTensorPasses.cpp binding source. This target still depends on pybind11.
py_extension(
    name = "_mlirSparseTensorPasses",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPISparseTensorHeaders",
        "@pybind11",
    ],
)
# Mosaic GPU extension module, compiled from the local mosaic_gpu_ext.cc.
#
# The extension depends on nanobind, so it uses the nanobind flavour of the
# MLIR Python binding headers (the same target `_tpu_ext` and `_triton_ext`
# use) instead of the pybind11-based `MLIRBindingsPythonHeadersAndDeps`, which
# would drag an otherwise unused pybind11 dependency into the build.
py_extension(
    name = "_mosaic_gpu_ext",
    srcs = ["mosaic_gpu_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
    ],
)
# This lives here, rather than in jaxlib/mosaic/python, so that it ends up in
# the same directory as libjaxlib_mlir_capi.so (produced by
# :jaxlib_mlir_capi_shared_library). That keeps the RPATH working correctly
# across platforms. It's not clear whether Windows supports RPATH-like
# functionality across different directories at all.
py_extension(
    name = "_tpu_ext",
    srcs = ["tpu_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "//jaxlib/mosaic:tpu_dialect_capi_headers",
        "@com_google_absl//absl/log:check",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIRHeaders",
        # nanobind flavour of the MLIR Python binding headers.
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/tsl/python/lib/core:numpy",
    ],
)
# This target contains the extension and its Python dependencies, which are not
# supported by the `py_extension`/`pybind_extension` macros.
py_library(
    name = "_tpu_ext_lib",
    deps = [
        # The native extension itself.
        ":_tpu_ext",
        # Python-level dependencies bundled alongside it.
        "//jaxlib/mlir:ir",
        "//jaxlib/mosaic/python:layout_defs",
    ],
)
# Triton extension module, compiled from the local triton_ext.cc, with type
# stubs supplied by _triton_ext.pyi.
#
# NOTE(review): this target is declared with the `pybind_extension` macro even
# though its deps are the nanobind variants — presumably the macro is kept for
# its `pytype_srcs` support; confirm against //jaxlib:jax.bzl.
pybind_extension(
    name = "_triton_ext",
    srcs = ["triton_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    pytype_srcs = ["_triton_ext.pyi"],
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "//jaxlib/triton:triton_dialect_capi_headers",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
    ],
)
# py_library assembled through `symlink_inputs`: the upstream MlirLibsPyFiles
# sources are mapped into this package's directory ("."), and the library
# depends on the `_mlir` and `register_jax_dialects` extensions.
symlink_inputs(
    name = "_mlir_libs",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": ["@llvm-project//mlir/python:MlirLibsPyFiles"],
        },
    },
    deps = [
        ":_mlir",
        ":register_jax_dialects",
    ],
)
# C API shim library: jaxlib_mlir_capi_shims.{h,cc} plus the MLIR translation,
# pipeline and target libraries it is built against.
cc_library(
    name = "jaxlib_mlir_capi_shims",
    srcs = ["jaxlib_mlir_capi_shims.cc"],
    hdrs = ["jaxlib_mlir_capi_shims.h"],
    deps = [
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:GPUPipelines",
        "@llvm-project//mlir:GPUToLLVMIRTranslation",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:MemRefTransforms",
        "@llvm-project//mlir:NVVMTarget",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
    ],
    # Always link this library's object files into dependents, even when no
    # symbol from them is referenced.
    alwayslink = 1,
)
# Header-only view of the shim library: exposes jaxlib_mlir_capi_shims.h
# without compiling jaxlib_mlir_capi_shims.cc.
cc_library(
    name = "jaxlib_mlir_capi_shims_hdrs",
    hdrs = ["jaxlib_mlir_capi_shims.h"],
    deps = ["@llvm-project//mlir:CAPIIRHeaders"],
)
# JAX-specific registrations: a nanobind-based extension module compiled from
# register_jax_dialects.cc against the C API headers of the dialects listed in
# `deps`.
py_extension(
    name = "register_jax_dialects",
    srcs = ["register_jax_dialects.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        # MLIR C API headers.
        "@llvm-project//mlir:CAPIArithHeaders",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPILLVMHeaders",
        "@llvm-project//mlir:CAPIMathHeaders",
        "@llvm-project//mlir:CAPIMemRefHeaders",
        "@llvm-project//mlir:CAPINVGPUHeaders",
        "@llvm-project//mlir:CAPINVVMHeaders",
        "@llvm-project//mlir:CAPISCFHeaders",
        "@llvm-project//mlir:CAPITransformsHeaders",
        "@llvm-project//mlir:CAPIVectorHeaders",
        # Python binding support.
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
        "@local_config_python//:headers",
        "@nanobind",
        # Shardy C API headers.
        "@shardy//shardy/integrations/c:sdy_capi_headers",
    ],
)
##---------------------------------------------------------------------------##
# MHLO Extensions
##---------------------------------------------------------------------------##
# MHLO extension module, compiled from MlirHloModule.cc in XLA's mlir_hlo
# tree. This target still depends on pybind11.
py_extension(
    name = "_mlirHlo",
    srcs = ["@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@xla//xla/mlir_hlo:CAPIHeaders",
    ],
)
##---------------------------------------------------------------------------##
# Shardy Extensions
##---------------------------------------------------------------------------##
# Shardy extension module, compiled from sdy_module.cc in the Shardy
# repository. This target still depends on pybind11.
py_extension(
    name = "_sdy",
    srcs = ["@shardy//shardy/integrations/python/ir:sdy_module.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@shardy//shardy/integrations/c:sdy_capi_headers",
    ],
)
##---------------------------------------------------------------------------##
# Stablehlo Extensions
##---------------------------------------------------------------------------##
# CHLO extension module, compiled from the StableHLO repository's
# `chlo_py_api_files` sources. This target still depends on pybind11.
py_extension(
    name = "_chlo",
    srcs = ["@stablehlo//:chlo_py_api_files"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@stablehlo//:chlo_capi_headers",
    ],
)
# StableHLO extension module, compiled from the StableHLO repository's
# `stablehlo_py_api_files` sources. This target still depends on pybind11.
py_extension(
    name = "_stablehlo",
    srcs = ["@stablehlo//:stablehlo_py_api_files"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@stablehlo//:stablehlo_capi_headers",
    ],
)
# Shared C++ extension library.
#
# Wraps the platform-specific shared object (.dll / .dylib / .so) built in this
# file so that the extension targets can depend on it under a single label.
cc_library(
    name = "jaxlib_mlir_capi_shared_library",
    srcs = select({
        "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"],
        "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"],
        "//conditions:default": [":libjaxlib_mlir_capi.so"],
    }),
    # On Windows the library additionally depends on the
    # `jaxlib_mlir_capi_dll` target that produces the DLL.
    deps = select({
        "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"],
        "//conditions:default": [],
    }),
)
# Aggregates every C API "objects" target that goes into the shared
# jaxlib_mlir_capi library (.so / .dylib / .dll) built in this file.
cc_library(
    name = "jaxlib_mlir_capi_objects",
    deps = [
        "//jaxlib/mosaic:tpu_dialect_capi_objects",
        "@llvm-project//mlir:CAPIArithObjects",
        "@llvm-project//mlir:CAPIGPUObjects",
        "@llvm-project//mlir:CAPIIRObjects",
        "@llvm-project//mlir:CAPILLVMObjects",
        "@llvm-project//mlir:CAPIMathObjects",
        "@llvm-project//mlir:CAPIMemRefObjects",
        "@llvm-project//mlir:CAPINVGPUObjects",
        "@llvm-project//mlir:CAPINVVMObjects",
        "@llvm-project//mlir:CAPISCFObjects",
        "@llvm-project//mlir:CAPISparseTensorObjects",
        "@llvm-project//mlir:CAPITransformsObjects",
        "@llvm-project//mlir:CAPIVectorObjects",
        "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
        "@shardy//shardy/integrations/c:sdy_capi_objects",
        "@stablehlo//:chlo_capi_objects",
        "@stablehlo//:stablehlo_capi_objects",
        "@xla//xla/mlir_hlo:CAPIObjects",
    ] + if_windows(
        # NOTE(review): presumably if_windows(a, b) yields `a` on Windows and
        # `b` elsewhere, i.e. the Triton objects are left out on Windows —
        # confirm against //jaxlib:jax.bzl.
        [],
        ["//jaxlib/triton:triton_dialect_capi_objects"],
    ),
)
# Shared object containing the aggregated C API objects; this is the default
# (non-Windows, non-macOS) choice in :jaxlib_mlir_capi_shared_library.
cc_binary(
    name = "libjaxlib_mlir_capi.so",
    linkopts = [
        # Record the library's soname.
        "-Wl,-soname=libjaxlib_mlir_capi.so",
        # Add the library's own directory to its runtime search path.
        "-Wl,-rpath='$$ORIGIN'",
    ],
    linkshared = 1,
    deps = [":jaxlib_mlir_capi_objects"],
)
# macOS dynamic library containing the aggregated C API objects; selected for
# macOS in :jaxlib_mlir_capi_shared_library.
cc_binary(
    name = "libjaxlib_mlir_capi.dylib",
    linkopts = [
        # Add the loading binary's directory to the runtime search path.
        "-Wl,-rpath,@loader_path/",
        # Set the install name so the library is located relative to its loader.
        "-Wl,-install_name,@loader_path/libjaxlib_mlir_capi.dylib",
    ],
    linkshared = 1,
    deps = [":jaxlib_mlir_capi_objects"],
)
# Windows DLL containing the aggregated C API objects, written to
# jaxlib_mlir_capi.dll. `exported_symbol_prefixes` lists the symbol-name
# prefixes handed to the macro for export.
windows_cc_shared_mlir_library(
    name = "jaxlib_mlir_capi_dll",
    out = "jaxlib_mlir_capi.dll",
    exported_symbol_prefixes = [
        "mlir",
        "chlo",
        "sdy",
        "stablehlo",
    ],
    deps = [":jaxlib_mlir_capi_objects"],
)