_mlirTransforms merged into _mlirRegisterEverything.

PiperOrigin-RevId: 462233907
This commit is contained in:
Parker Schuh 2022-07-20 14:42:56 -07:00 committed by jax authors
parent ad67d825fe
commit d8f0099f68
4 changed files with 28 additions and 22 deletions

View File

@ -206,14 +206,11 @@ def prepare_wheel(sources_path):
mlir_dir = os.path.join(jaxlib_dir, "mlir")
mlir_dialects_dir = os.path.join(jaxlib_dir, "mlir", "dialects")
mlir_libs_dir = os.path.join(jaxlib_dir, "mlir", "_mlir_libs")
mlir_transforms_dir = os.path.join(jaxlib_dir, "mlir", "transforms")
os.makedirs(mlir_dir)
os.makedirs(mlir_dialects_dir)
os.makedirs(mlir_libs_dir)
os.makedirs(mlir_transforms_dir)
copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir)
copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir)
copy_file("__main__/jaxlib/mlir/transforms/__init__.py", dst_dir=mlir_transforms_dir)
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir)
@ -231,11 +228,12 @@ def prepare_wheel(sources_path):
copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirRegisterEverything.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirTransforms.{pyext}", dst_dir=mlir_libs_dir)
if _is_windows():
copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir)
elif _is_mac():

View File

@ -44,7 +44,7 @@ symlink_inputs(
py_library(
name = "mlir",
deps = [
"//jaxlib/mlir/_mlir_libs:_mlir",
"//jaxlib/mlir/_mlir_libs:_mlir_libs",
],
)
@ -128,17 +128,3 @@ symlink_inputs(
":mlir",
],
)
symlink_inputs(
name = "transforms",
rule = py_library,
symlinked_inputs = {"srcs": {
"transforms": [
"@llvm-project//mlir/python:TransformsPyFiles",
],
}},
deps = [
":mlir",
"//jaxlib/mlir/_mlir_libs:_mlirTransforms",
],
)

View File

@ -18,6 +18,8 @@ load(
"windows_cc_shared_mlir_library",
)
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
package(
default_visibility = [
"//visibility:public",
@ -71,18 +73,36 @@ py_extension(
)
py_extension(
name = "_mlirTransforms",
name = "_mlirRegisterEverything",
srcs = [
"@llvm-project//mlir:lib/Bindings/Python/Transforms/Transforms.cpp",
"@llvm-project//mlir:lib/Bindings/Python/RegisterEverything.cpp",
"@llvm-project//mlir:include/mlir/Bindings/Python/PybindAdaptors.h",
"@llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h",
],
copts = COPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIConversionHeaders",
"@llvm-project//mlir:CAPIRegisterEverythingHeaders",
"@llvm-project//mlir:CAPITransformsHeaders",
"@pybind11",
],
)
symlink_inputs(
name = "_mlir_libs",
rule = py_library,
symlinked_inputs = {"srcs": {
".": [
"@llvm-project//mlir/python:MlirLibsPyFiles",
],
}},
deps = [
":_mlir",
":_mlirRegisterEverything",
],
)
##---------------------------------------------------------------------------##
# MHLO Extensions
##---------------------------------------------------------------------------##
@ -122,6 +142,8 @@ cc_library(
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
"@llvm-project//mlir:CAPISparseTensorObjects",
"@llvm-project//mlir:CAPITransformsObjects",
"@llvm-project//mlir:CAPIRegisterEverythingObjects",
"@llvm-project//mlir:CAPIConversionObjects",
"@org_tensorflow//tensorflow/compiler/mlir/hlo:CAPIObjects",
],
)

View File

@ -61,11 +61,11 @@ setup(
'cuda/nvvm/libdevice/libdevice*',
'mlir/*.py',
'mlir/dialects/*.py',
'mlir/transforms/*.py',
'mlir/_mlir_libs/*.dll',
'mlir/_mlir_libs/*.dylib',
'mlir/_mlir_libs/*.so',
'mlir/_mlir_libs/*.pyd',
'mlir/_mlir_libs/*.py',
'rocm/*',
],
'jaxlib.xla_extension': ['*.pyi'],