Fix jaxlib build failure after upstream MLIR Python binding changes.

https://github.com/llvm/llvm-project/pull/68853 changed the structure of
the upstream MLIR Python bindings, breaking the jaxlib build. Update our
build scripts to match.
This commit is contained in:
Peter Hawkins 2023-10-23 14:27:52 +00:00
parent 373c4212a4
commit caee898fd0
2 changed files with 6 additions and 10 deletions

View File

@ -40,13 +40,15 @@ def copy_file(
if isinstance(src_files, str):
src_files = [src_files]
for src_file in src_files:
src_file = runfiles.Rlocation(src_file)
src_filename = os.path.basename(src_file)
src_file_rloc = runfiles.Rlocation(src_file)
if src_file_rloc is None:
raise ValueError(f"Unable to find wheel source file {src_file}")
src_filename = os.path.basename(src_file_rloc)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
if is_windows():
shutil.copyfile(src_file, dst_file)
shutil.copyfile(src_file_rloc, dst_file)
else:
shutil.copy(src_file, dst_file)
shutil.copy(src_file_rloc, dst_file)
def platform_tag(cpu: str) -> str:

View File

@ -273,21 +273,15 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
dst_dir=jaxlib_dir / "mlir" / "dialects",
src_files=[
"__main__/jaxlib/mlir/dialects/_arith_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_arith_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_arith_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_func_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_func_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_math_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_memref_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_memref_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_ods_common.py",
"__main__/jaxlib/mlir/dialects/_scf_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_scf_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py",