mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

Now that https://github.com/llvm/llvm-project/pull/117922 has landed upstream, we can work towards removing our last uses of pybind11. PiperOrigin-RevId: 705872751
436 lines
12 KiB
Python
436 lines
12 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
load(
|
|
"//jaxlib:jax.bzl",
|
|
"if_windows",
|
|
"py_extension",
|
|
"pybind_extension",
|
|
"windows_cc_shared_mlir_library",
|
|
)
|
|
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
|
|
|
|
# Package-wide defaults: every target is publicly visible and carries the
# "notice" license tag.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)
|
|
|
|
# Compiler flags passed as `copts` by the extension targets in this file:
# turn on C++ exceptions and RTTI.
COPTS = ["-fexceptions", "-frtti"]
|
|
|
|
# Linker flags passed as `linkopts` by the extension targets in this file,
# chosen per platform:
#   * macOS:   rpath of @loader_path/ plus a rename of section
#              __TEXT,text_env to __TEXT,__text.
#   * Windows: no extra flags.
#   * default: rpath of $ORIGIN/ (`$$` is Bazel's escape for a literal `$`).
LINKOPTS = select({
    "@xla//xla/tsl:macos": [
        "-Wl,-rpath,@loader_path/",
        "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text",
    ],
    "@xla//xla/tsl:windows": [],
    "//conditions:default": ["-Wl,-rpath,$$ORIGIN/"],
})
|
|
|
|
# Python extension compiled from upstream MLIR's MainModule.cpp. It depends on
# the shared CAPI library defined in this package rather than on the
# CAPI-bundling variant of the MLIR Python core (note the "NoCAPI" dep).
py_extension(
    name = "_mlir",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
    ],
)
|
|
|
|
# Python extension compiled from upstream MLIR's DialectGPU.cpp binding source.
# Uses the GPU and IR CAPI headers and pybind11.
py_extension(
    name = "_mlirDialectsGPU",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
# Python extension compiled from upstream MLIR's GPUPasses.cpp binding source.
py_extension(
    name = "_mlirGPUPasses",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@pybind11",
    ],
)
|
|
|
|
# Python extension compiled from upstream MLIR's DialectNVGPU.cpp binding
# source. Uses the IR and NVGPU CAPI headers and pybind11.
py_extension(
    name = "_mlirDialectsNVGPU",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPINVGPUHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
# Python extension compiled from upstream MLIR's DialectLLVM.cpp binding
# source. Uses the IR and LLVM CAPI headers and pybind11.
py_extension(
    name = "_mlirDialectsLLVM",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPILLVMHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
# Python extension compiled from upstream MLIR's DialectSparseTensor.cpp
# binding source. Uses the SparseTensor CAPI headers and pybind11.
py_extension(
    name = "_mlirDialectsSparseTensor",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPISparseTensorHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@pybind11",
    ],
)
|
|
|
|
# Python extension compiled from upstream MLIR's SparseTensorPasses.cpp
# binding source.
py_extension(
    name = "_mlirSparseTensorPasses",
    srcs = ["@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPISparseTensorHeaders",
        "@pybind11",
    ],
)
|
|
|
|
# Python extension for the Mosaic GPU dialect, built from the local
# mosaic_gpu_ext.cc with nanobind.
# NOTE(review): unlike its siblings, this target does not list
# :jaxlib_mlir_capi_shared_library and depends on the dialect's CAPI target
# directly — confirm this is intentional.
py_extension(
    name = "_mosaic_gpu_ext",
    srcs = ["mosaic_gpu_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
        "@nanobind",
    ],
)
|
|
|
|
# Python extension for the Mosaic TPU dialect, built from tpu_ext.cc.
#
# It lives here rather than in jaxlib/mosaic/python so that it sits in the same
# directory as libjaxlib_mlir_capi.so (exposed through
# :jaxlib_mlir_capi_shared_library); that keeps the RPATH working on every
# platform. It is unclear whether Windows offers any RPATH-like lookup across
# different directories at all.
py_extension(
    name = "_tpu_ext",
    srcs = ["tpu_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "//jaxlib/mosaic:tpu_dialect_capi_headers",
        "@com_google_absl//absl/log:check",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
        "@xla//xla/python:nb_numpy",
        "@xla//xla/tsl/python/lib/core:numpy",
    ],
)
|
|
|
|
# Bundles the :_tpu_ext extension together with its Python dependencies, which
# the `py_extension`/`pybind_extension` macros do not support.
py_library(
    name = "_tpu_ext_lib",
    deps = [
        ":_tpu_ext",
        "//jaxlib/mlir:ir",
        "//jaxlib/mosaic/python:layout_defs",
    ],
)
|
|
|
|
# Python extension for the Triton dialect, built from triton_ext.cc and shipped
# with a type stub (_triton_ext.pyi).
# NOTE(review): declared through the `pybind_extension` macro although its
# binding dependencies are the nanobind ones — confirm the macro choice.
pybind_extension(
    name = "_triton_ext",
    srcs = ["triton_ext.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    pytype_srcs = ["_triton_ext.pyi"],
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "//jaxlib/triton:triton_dialect_capi_headers",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
        "@nanobind",
    ],
)
|
|
|
|
# py_library instantiated through `symlink_inputs`: upstream's MlirLibsPyFiles
# are supplied as `srcs` under the "." key (presumably the destination
# directory for the symlinks — see symlink_files.bzl). Depends on the core
# :_mlir extension and the JAX dialect registration extension.
symlink_inputs(
    name = "_mlir_libs",
    rule = py_library,
    symlinked_inputs = {
        "srcs": {
            ".": ["@llvm-project//mlir/python:MlirLibsPyFiles"],
        },
    },
    deps = [
        ":_mlir",
        ":register_jax_dialects",
    ],
)
|
|
|
|
# C++ library holding the jaxlib MLIR CAPI shims (source and header). It pulls
# in the LLVM-IR translation, GPU pipeline, MemRef transform and NVVM target
# libraries, and is marked `alwayslink` so its objects are kept by the linker
# even when nothing references their symbols directly.
cc_library(
    name = "jaxlib_mlir_capi_shims",
    srcs = ["jaxlib_mlir_capi_shims.cc"],
    hdrs = ["jaxlib_mlir_capi_shims.h"],
    deps = [
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:GPUPipelines",
        "@llvm-project//mlir:GPUToLLVMIRTranslation",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:MemRefTransforms",
        "@llvm-project//mlir:NVVMTarget",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
    ],
    alwayslink = 1,
)
|
|
|
|
# Header-only view of the CAPI shims: exposes jaxlib_mlir_capi_shims.h without
# the implementation or its heavier dependencies.
cc_library(
    name = "jaxlib_mlir_capi_shims_hdrs",
    hdrs = ["jaxlib_mlir_capi_shims.h"],
    deps = ["@llvm-project//mlir:CAPIIRHeaders"],
)
|
|
|
|
# JAX-specific registrations: a nanobind extension built from
# register_jax_dialects.cc, compiled against the CAPI headers of the dialects
# listed below plus Shardy's sdy CAPI headers.
py_extension(
    name = "register_jax_dialects",
    srcs = ["register_jax_dialects.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIArithHeaders",
        "@llvm-project//mlir:CAPIGPUHeaders",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:CAPILLVMHeaders",
        "@llvm-project//mlir:CAPIMathHeaders",
        "@llvm-project//mlir:CAPIMemRefHeaders",
        "@llvm-project//mlir:CAPINVGPUHeaders",
        "@llvm-project//mlir:CAPINVVMHeaders",
        "@llvm-project//mlir:CAPISCFHeaders",
        "@llvm-project//mlir:CAPITransformsHeaders",
        "@llvm-project//mlir:CAPIVectorHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders",
        "@local_config_python//:headers",
        "@nanobind",
        "@shardy//shardy/integrations/c:sdy_capi_headers",
    ],
)
|
|
|
|
##---------------------------------------------------------------------------##
|
|
# MHLO Extensions
|
|
##---------------------------------------------------------------------------##
|
|
|
|
# Python extension compiled from XLA's MlirHloModule.cc binding source, against
# the mlir_hlo CAPI headers and pybind11.
py_extension(
    name = "_mlirHlo",
    srcs = ["@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@xla//xla/mlir_hlo:CAPIHeaders",
    ],
)
|
|
|
|
##---------------------------------------------------------------------------##
|
|
# Shardy Extensions
|
|
##---------------------------------------------------------------------------##
|
|
|
|
# Python extension compiled from Shardy's sdy_module.cc binding source, against
# the sdy CAPI headers and pybind11.
py_extension(
    name = "_sdy",
    srcs = ["@shardy//shardy/integrations/python/ir:sdy_module.cc"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@shardy//shardy/integrations/c:sdy_capi_headers",
    ],
)
|
|
|
|
##---------------------------------------------------------------------------##
|
|
# Stablehlo Extensions
|
|
##---------------------------------------------------------------------------##
|
|
|
|
# Python extension compiled from StableHLO's chlo_py_api_files, against the
# CHLO CAPI headers and pybind11.
py_extension(
    name = "_chlo",
    srcs = ["@stablehlo//:chlo_py_api_files"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@stablehlo//:chlo_capi_headers",
    ],
)
|
|
|
|
# Python extension compiled from StableHLO's stablehlo_py_api_files, against
# the StableHLO CAPI headers, LLVM Support and pybind11.
py_extension(
    name = "_stablehlo",
    srcs = ["@stablehlo//:stablehlo_py_api_files"],
    copts = COPTS,
    linkopts = LINKOPTS,
    deps = [
        ":jaxlib_mlir_capi_shared_library",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIRHeaders",
        "@llvm-project//mlir:MLIRBindingsPythonHeaders",
        "@local_config_python//:headers",
        "@pybind11",
        "@stablehlo//:stablehlo_capi_headers",
    ],
)
|
|
|
|
# Shared C++ extension library
|
|
|
|
# Wrapper that lets other targets depend on the shared CAPI library. `srcs`
# picks the platform's binary (.dll on Windows, .dylib on macOS, .so
# otherwise); on Windows it additionally depends on :jaxlib_mlir_capi_dll.
cc_library(
    name = "jaxlib_mlir_capi_shared_library",
    srcs = select({
        "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"],
        "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"],
        "//conditions:default": [":libjaxlib_mlir_capi.so"],
    }),
    deps = select({
        "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"],
        "//conditions:default": [],
    }),
)
|
|
|
|
# Aggregates every CAPI "objects" target that goes into the shared CAPI
# library; the .so, .dylib and .dll targets in this file all depend on it.
cc_library(
    name = "jaxlib_mlir_capi_objects",
    deps = [
        "//jaxlib/mosaic:tpu_dialect_capi_objects",
        "@llvm-project//mlir:CAPIArithObjects",
        "@llvm-project//mlir:CAPIGPUObjects",
        "@llvm-project//mlir:CAPIIRObjects",
        "@llvm-project//mlir:CAPILLVMObjects",
        "@llvm-project//mlir:CAPIMathObjects",
        "@llvm-project//mlir:CAPIMemRefObjects",
        "@llvm-project//mlir:CAPINVGPUObjects",
        "@llvm-project//mlir:CAPINVVMObjects",
        "@llvm-project//mlir:CAPISCFObjects",
        "@llvm-project//mlir:CAPISparseTensorObjects",
        "@llvm-project//mlir:CAPITransformsObjects",
        "@llvm-project//mlir:CAPIVectorObjects",
        "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
        "@shardy//shardy/integrations/c:sdy_capi_objects",
        "@stablehlo//:chlo_capi_objects",
        "@stablehlo//:stablehlo_capi_objects",
        "@xla//xla/mlir_hlo:CAPIObjects",
    ] + if_windows(
        # Empty first list, Triton objects in the second — presumably the
        # Windows / non-Windows branches respectively; confirm in jax.bzl.
        [],
        ["//jaxlib/triton:triton_dialect_capi_objects"],
    ),
)
|
|
|
|
# Shared object form of the CAPI library (`linkshared = 1`). The link flags set
# the soname to libjaxlib_mlir_capi.so and add an rpath of $ORIGIN (`$$` is
# Bazel's escape for a literal `$`).
cc_binary(
    name = "libjaxlib_mlir_capi.so",
    linkopts = [
        "-Wl,-soname=libjaxlib_mlir_capi.so",
        "-Wl,-rpath='$$ORIGIN'",
    ],
    linkshared = 1,
    deps = [":jaxlib_mlir_capi_objects"],
)
|
|
|
|
# Dynamic library form of the CAPI library (`linkshared = 1`). The link flags
# add an rpath of @loader_path/ and set the install name to
# @loader_path/libjaxlib_mlir_capi.dylib.
cc_binary(
    name = "libjaxlib_mlir_capi.dylib",
    linkopts = [
        "-Wl,-rpath,@loader_path/",
        "-Wl,-install_name,@loader_path/libjaxlib_mlir_capi.dylib",
    ],
    linkshared = 1,
    deps = [":jaxlib_mlir_capi_objects"],
)
|
|
|
|
# DLL form of the CAPI library, written to jaxlib_mlir_capi.dll. The listed
# prefixes are handed to the macro as `exported_symbol_prefixes`.
windows_cc_shared_mlir_library(
    name = "jaxlib_mlir_capi_dll",
    out = "jaxlib_mlir_capi.dll",
    exported_symbol_prefixes = [
        "mlir",
        "chlo",
        "sdy",
        "stablehlo",
    ],
    deps = [":jaxlib_mlir_capi_objects"],
)
|