diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 783190b1e..f57a9ae08 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -48,7 +48,6 @@ py_library_providing_imports_info( "//jaxlib/mlir:math_dialect", "//jaxlib/mlir:memref_dialect", "//jaxlib/mlir:mhlo_dialect", - "//jaxlib/mlir:ml_program_dialect", "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:sparse_tensor_dialect", diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index 00c16b2d5..ae47aacc9 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -20,7 +20,6 @@ import jaxlib.mlir.dialects.math as math import jaxlib.mlir.dialects.memref as memref import jaxlib.mlir.dialects.mhlo as mhlo import jaxlib.mlir.dialects.func as func -import jaxlib.mlir.dialects.ml_program as ml_program import jaxlib.mlir.dialects.scf as scf import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor import jaxlib.mlir.dialects.vector as vector diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 6ba5dc0d8..cb9d6e6b4 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -62,7 +62,6 @@ py_library_providing_imports_info( "//jaxlib/mlir:ir", "//jaxlib/mlir:memref_dialect", "//jaxlib/mlir:mhlo_dialect", - "//jaxlib/mlir:ml_program_dialect", "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:sparse_tensor_dialect", diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index b7abed018..ae2741586 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -126,19 +126,6 @@ symlink_inputs( ], ) -symlink_inputs( - name = "ml_program_dialect", - rule = py_library, - symlinked_inputs = {"srcs": {"dialects": [ - "@llvm-project//mlir/python:MLProgramOpsPyFiles", - ]}}, - deps = [ - ":core", - ":ir", - ":mlir", - ], -) - symlink_inputs( name = "builtin_dialect", rule = py_library, diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 0dd86a11a..4d4e1a96a 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -276,7 +276,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mlir/dialects/_math_ops_gen.py", "__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", "__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", "__main__/jaxlib/mlir/dialects/_ods_common.py", "__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", "__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", @@ -291,7 +290,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/mlir/dialects/math.py", "__main__/jaxlib/mlir/dialects/memref.py", "__main__/jaxlib/mlir/dialects/mhlo.py", - "__main__/jaxlib/mlir/dialects/ml_program.py", "__main__/jaxlib/mlir/dialects/scf.py", "__main__/jaxlib/mlir/dialects/sparse_tensor.py", "__main__/jaxlib/mlir/dialects/stablehlo.py",