mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
_mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
This commit is contained in:
parent
ad67d825fe
commit
d8f0099f68
@ -206,14 +206,11 @@ def prepare_wheel(sources_path):
|
||||
mlir_dir = os.path.join(jaxlib_dir, "mlir")
|
||||
mlir_dialects_dir = os.path.join(jaxlib_dir, "mlir", "dialects")
|
||||
mlir_libs_dir = os.path.join(jaxlib_dir, "mlir", "_mlir_libs")
|
||||
mlir_transforms_dir = os.path.join(jaxlib_dir, "mlir", "transforms")
|
||||
os.makedirs(mlir_dir)
|
||||
os.makedirs(mlir_dialects_dir)
|
||||
os.makedirs(mlir_libs_dir)
|
||||
os.makedirs(mlir_transforms_dir)
|
||||
copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir)
|
||||
copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir)
|
||||
copy_file("__main__/jaxlib/mlir/transforms/__init__.py", dst_dir=mlir_transforms_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
@ -231,11 +228,12 @@ def prepare_wheel(sources_path):
|
||||
copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir)
|
||||
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirRegisterEverything.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirTransforms.{pyext}", dst_dir=mlir_libs_dir)
|
||||
if _is_windows():
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir)
|
||||
elif _is_mac():
|
||||
|
@ -44,7 +44,7 @@ symlink_inputs(
|
||||
py_library(
|
||||
name = "mlir",
|
||||
deps = [
|
||||
"//jaxlib/mlir/_mlir_libs:_mlir",
|
||||
"//jaxlib/mlir/_mlir_libs:_mlir_libs",
|
||||
],
|
||||
)
|
||||
|
||||
@ -128,17 +128,3 @@ symlink_inputs(
|
||||
":mlir",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "transforms",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {
|
||||
"transforms": [
|
||||
"@llvm-project//mlir/python:TransformsPyFiles",
|
||||
],
|
||||
}},
|
||||
deps = [
|
||||
":mlir",
|
||||
"//jaxlib/mlir/_mlir_libs:_mlirTransforms",
|
||||
],
|
||||
)
|
||||
|
@ -18,6 +18,8 @@ load(
|
||||
"windows_cc_shared_mlir_library",
|
||||
)
|
||||
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
@ -71,18 +73,36 @@ py_extension(
|
||||
)
|
||||
|
||||
py_extension(
|
||||
name = "_mlirTransforms",
|
||||
name = "_mlirRegisterEverything",
|
||||
srcs = [
|
||||
"@llvm-project//mlir:lib/Bindings/Python/Transforms/Transforms.cpp",
|
||||
"@llvm-project//mlir:lib/Bindings/Python/RegisterEverything.cpp",
|
||||
"@llvm-project//mlir:include/mlir/Bindings/Python/PybindAdaptors.h",
|
||||
"@llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h",
|
||||
],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPIConversionHeaders",
|
||||
"@llvm-project//mlir:CAPIRegisterEverythingHeaders",
|
||||
"@llvm-project//mlir:CAPITransformsHeaders",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "_mlir_libs",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {
|
||||
".": [
|
||||
"@llvm-project//mlir/python:MlirLibsPyFiles",
|
||||
],
|
||||
}},
|
||||
deps = [
|
||||
":_mlir",
|
||||
":_mlirRegisterEverything",
|
||||
],
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
# MHLO Extensions
|
||||
##---------------------------------------------------------------------------##
|
||||
@ -122,6 +142,8 @@ cc_library(
|
||||
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
|
||||
"@llvm-project//mlir:CAPISparseTensorObjects",
|
||||
"@llvm-project//mlir:CAPITransformsObjects",
|
||||
"@llvm-project//mlir:CAPIRegisterEverythingObjects",
|
||||
"@llvm-project//mlir:CAPIConversionObjects",
|
||||
"@org_tensorflow//tensorflow/compiler/mlir/hlo:CAPIObjects",
|
||||
],
|
||||
)
|
||||
|
@ -61,11 +61,11 @@ setup(
|
||||
'cuda/nvvm/libdevice/libdevice*',
|
||||
'mlir/*.py',
|
||||
'mlir/dialects/*.py',
|
||||
'mlir/transforms/*.py',
|
||||
'mlir/_mlir_libs/*.dll',
|
||||
'mlir/_mlir_libs/*.dylib',
|
||||
'mlir/_mlir_libs/*.so',
|
||||
'mlir/_mlir_libs/*.pyd',
|
||||
'mlir/_mlir_libs/*.py',
|
||||
'rocm/*',
|
||||
],
|
||||
'jaxlib.xla_extension': ['*.pyi'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user