Merge pull request #17864 from hawkinsp:buildwheel

PiperOrigin-RevId: 569576169
This commit is contained in:
jax authors 2023-09-29 13:34:42 -07:00
commit 095c367c01
3 changed files with 265 additions and 208 deletions

View File

@ -14,12 +14,16 @@
"""Utilities for the building JAX related python packages."""
from __future__ import annotations
import os
import pathlib
import platform
import shutil
import sys
import subprocess
import glob
from typing import Sequence
def is_windows() -> bool:
@ -27,20 +31,22 @@ def is_windows() -> bool:
def copy_file(
src_file: str,
dst_dir: str,
dst_filename=None,
from_runfiles=True,
runfiles=None,
src_files: str | Sequence[str],
dst_dir: pathlib.Path,
dst_filename = None,
runfiles = None,
) -> None:
if from_runfiles:
dst_dir.mkdir(parents=True, exist_ok=True)
if isinstance(src_files, str):
src_files = [src_files]
for src_file in src_files:
src_file = runfiles.Rlocation(src_file)
src_filename = os.path.basename(src_file)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
if is_windows():
shutil.copyfile(src_file, dst_file)
else:
shutil.copy(src_file, dst_file)
src_filename = os.path.basename(src_file)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
if is_windows():
shutil.copyfile(src_file, dst_file)
else:
shutil.copy(src_file, dst_file)
def platform_tag(cpu: str) -> str:

View File

@ -18,7 +18,9 @@
# Most users should not run this script directly; use build.py instead.
import argparse
import functools
import os
import pathlib
import tempfile
from bazel_tools.tools.python.runfiles import runfiles
@ -26,20 +28,20 @@ from jax.tools import build_utils
parser = argparse.ArgumentParser()
parser.add_argument(
"--sources_path",
default=None,
help="Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used.")
"--sources_path",
default=None,
help="Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used.",
)
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.")
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
"--cpu",
default=None,
required=True,
help="Target CPU architecture. Required.")
"--cpu", default=None, required=True, help="Target CPU architecture. Required."
)
parser.add_argument(
"--cuda_version",
default=None,
@ -58,51 +60,53 @@ r = runfiles.Create()
def write_setup_cfg(sources_path, cpu):
tag = build_utils.platform_tag(cpu)
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
f.write(f"""[metadata]
with open(sources_path / "setup.cfg", "w") as f:
f.write(
f"""[metadata]
license_files = LICENSE.txt
[bdist_wheel]
plat-name={tag}
python-tag=py3
""")
"""
)
def update_setup(file_dir, cuda_version):
src_file = os.path.join(file_dir, "setup.py")
with open(os.path.join(src_file), "r") as f:
src_file = file_dir / "setup.py"
with open(src_file, "r") as f:
content = f.read()
content = content.replace(
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
)
with open(os.path.join(src_file), "w") as f:
with open(src_file, "w") as f:
f.write(content)
def prepare_cuda_plugin_wheel(sources_path, *, cpu, cuda_version):
def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
"""Assembles a source tree for the wheel in `sources_path`."""
jax_plugins_dir = os.path.join(sources_path, "jax_plugins")
os.makedirs(jax_plugins_dir)
plugin_dir = os.path.join(jax_plugins_dir, f"xla_cuda_cu{cuda_version}")
os.makedirs(plugin_dir)
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
build_utils.copy_file(
"__main__/plugins/cuda/pyproject.toml", dst_dir=sources_path, runfiles=r
)
build_utils.copy_file(
"__main__/plugins/cuda/setup.py", dst_dir=sources_path, runfiles=r
plugin_dir = sources_path / "jax_plugins" / f"xla_cuda_cu{cuda_version}"
copy_runfiles(
dst_dir=sources_path,
src_files=[
"__main__/plugins/cuda/pyproject.toml",
"__main__/plugins/cuda/setup.py",
],
)
update_setup(sources_path, cuda_version)
write_setup_cfg(sources_path, cpu)
build_utils.copy_file(
"__main__/plugins/cuda/__init__.py", dst_dir=plugin_dir, runfiles=r
copy_runfiles(
dst_dir=plugin_dir,
src_files=[
"__main__/plugins/cuda/__init__.py",
],
)
plugin_so_path = r.Rlocation("xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so")
build_utils.copy_file(
plugin_so_path,
copy_runfiles(
"xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so",
dst_dir=plugin_dir,
dst_filename="xla_cuda_plugin.so",
runfiles=r,
)
@ -115,7 +119,7 @@ if sources_path is None:
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_cuda_plugin_wheel(
sources_path, cpu=args.cpu, cuda_version=args.cuda_version
pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version
)
package_name = "jax cuda plugin"
if args.editable:

View File

@ -20,6 +20,7 @@
import argparse
import functools
import os
import pathlib
import platform
import re
import subprocess
@ -30,24 +31,25 @@ from jax.tools import build_utils
parser = argparse.ArgumentParser()
parser.add_argument(
"--sources_path",
default=None,
help="Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used.")
"--sources_path",
default=None,
help="Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used.",
)
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.")
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
"--cpu",
default=None,
required=True,
help="Target CPU architecture. Required.")
"--cpu", default=None, required=True, help="Target CPU architecture. Required."
)
parser.add_argument(
"--editable",
action="store_true",
help="Create an 'editable' jaxlib build instead of a wheel.")
"--editable",
action="store_true",
help="Create an 'editable' jaxlib build instead of a wheel.",
)
parser.add_argument(
"--include_gpu_plugin_extension",
# args.include_gpu_plugin_extension is True when
@ -77,10 +79,16 @@ def patch_copy_mlir_import(src_file, dst_dir):
with open(src_file) as f:
src = f.read()
with open(os.path.join(dst_dir, src_filename), 'w') as f:
replaced = re.sub(r'^from mlir(\..*)? import (.*)', r'from jaxlib.mlir\1 import \2', src, flags=re.MULTILINE)
with open(dst_dir / src_filename, "w") as f:
replaced = re.sub(
r"^from mlir(\..*)? import (.*)",
r"from jaxlib.mlir\1 import \2",
src,
flags=re.MULTILINE,
)
f.write(replaced)
_XLA_EXTENSION_STUBS = [
"__init__.pyi",
"jax_jit.pyi",
@ -91,28 +99,21 @@ _XLA_EXTENSION_STUBS = [
"pytree.pyi",
"transfer_guard_lib.pyi",
]
_OPTIONAL_XLA_EXTENSION_STUBS = [
]
_OPTIONAL_XLA_EXTENSION_STUBS = []
def patch_copy_xla_extension_stubs(dst_dir):
# This file is required by PEP-561. It marks jaxlib as package containing
# type stubs.
with open(os.path.join(dst_dir, "py.typed"), "w"):
pass
xla_extension_dir = os.path.join(dst_dir, "xla_extension")
os.makedirs(xla_extension_dir)
for stub_name in _XLA_EXTENSION_STUBS:
stub_path = r.Rlocation(
"xla/xla/python/xla_extension/" + stub_name)
stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name)
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
continue
with open(stub_path) as f:
src = f.read()
src = src.replace(
"from xla.python import xla_extension",
"from .. import xla_extension"
"from xla.python import xla_extension", "from .. import xla_extension"
)
with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
f.write(src)
@ -130,160 +131,206 @@ def verify_mac_libraries_dont_reference_chkstack():
if not _is_mac():
return
nm = subprocess.run(
["nm", "-g",
r.Rlocation("xla/xla/python/xla_extension.so")
],
capture_output=True, text=True,
check=False)
["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")],
capture_output=True,
text=True,
check=False,
)
if nm.returncode != 0:
raise RuntimeError(f"nm process failed: {nm.stdout} {nm.stderr}")
if "____chkstk_darwin" in nm.stdout:
raise RuntimeError(
"Mac wheel incorrectly depends on symbol ____chkstk_darwin, which "
"means that it isn't compatible with older MacOS versions.")
"Mac wheel incorrectly depends on symbol ____chkstk_darwin, which "
"means that it isn't compatible with older MacOS versions."
)
def write_setup_cfg(sources_path, cpu):
tag = build_utils.platform_tag(cpu)
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
f.write(f"""[metadata]
with open(sources_path / "setup.cfg", "w") as f:
f.write(
f"""[metadata]
license_files = LICENSE.txt
[bdist_wheel]
plat-name={tag}
""")
"""
)
def prepare_wheel(sources_path, *, cpu, include_gpu_plugin_extension):
def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extension):
"""Assembles a source tree for the wheel in `sources_path`."""
jaxlib_dir = os.path.join(sources_path, "jaxlib")
os.makedirs(jaxlib_dir)
copy_to_jaxlib = functools.partial(build_utils.copy_file, dst_dir=jaxlib_dir, runfiles=r)
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
verify_mac_libraries_dont_reference_chkstack()
build_utils.copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path, runfiles=r)
build_utils.copy_file("__main__/jaxlib/README.md", dst_dir=sources_path, runfiles=r)
build_utils.copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path, runfiles=r)
copy_runfiles(
dst_dir=sources_path,
src_files=[
"__main__/jaxlib/tools/LICENSE.txt",
"__main__/jaxlib/README.md",
"__main__/jaxlib/setup.py",
],
)
write_setup_cfg(sources_path, cpu)
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
jaxlib_dir = sources_path / "jaxlib"
copy_runfiles(
"__main__/jaxlib/init.py", dst_dir=jaxlib_dir, dst_filename="__init__.py"
)
if include_gpu_plugin_extension:
copy_to_jaxlib(f"__main__/jaxlib/cuda_plugin_extension.{pyext}")
copy_to_jaxlib(f"__main__/jaxlib/utils.{pyext}")
copy_to_jaxlib("__main__/jaxlib/lapack.py")
copy_to_jaxlib("__main__/jaxlib/hlo_helpers.py")
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
copy_to_jaxlib("__main__/jaxlib/gpu_rnn.py")
copy_to_jaxlib("__main__/jaxlib/gpu_triton.py")
copy_to_jaxlib("__main__/jaxlib/gpu_common_utils.py")
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py")
copy_to_jaxlib("__main__/jaxlib/tpu_mosaic.py")
copy_to_jaxlib("__main__/jaxlib/version.py")
copy_to_jaxlib("__main__/jaxlib/xla_client.py")
copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}")
cpu_dir = os.path.join(jaxlib_dir, "cpu")
os.makedirs(cpu_dir)
build_utils.copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir, runfiles=r)
cuda_dir = os.path.join(jaxlib_dir, "cuda")
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
os.makedirs(libdevice_dir)
build_utils.copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_versions.{pyext}", dst_dir=cuda_dir, runfiles=r)
rocm_dir = os.path.join(jaxlib_dir, "rocm")
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
os.makedirs(rocm_dir)
build_utils.copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir, runfiles=r)
if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
build_utils.copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir, runfiles=r)
if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
build_utils.copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir, runfiles=r)
mosaic_dir = os.path.join(jaxlib_dir, "mosaic")
mosaic_python_dir = os.path.join(mosaic_dir, "python")
os.makedirs(mosaic_dir)
os.makedirs(mosaic_python_dir)
copy_to_jaxlib("__main__/jaxlib/mosaic/python/apply_vector_layout.py", dst_dir=mosaic_python_dir)
copy_to_jaxlib("__main__/jaxlib/mosaic/python/infer_memref_layout.py", dst_dir=mosaic_python_dir)
copy_to_jaxlib("__main__/jaxlib/mosaic/python/tpu.py", dst_dir=mosaic_python_dir)
build_utils.copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir, runfiles=r)
# TODO (sharadmv,skyewm): can we avoid patching this file?
patch_copy_mlir_import("__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir)
mlir_dir = os.path.join(jaxlib_dir, "mlir")
mlir_dialects_dir = os.path.join(jaxlib_dir, "mlir", "dialects")
mlir_libs_dir = os.path.join(jaxlib_dir, "mlir", "_mlir_libs")
os.makedirs(mlir_dir)
os.makedirs(mlir_dialects_dir)
os.makedirs(mlir_libs_dir)
build_utils.copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
if build_utils.is_windows():
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir, runfiles=r)
elif _is_mac():
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir, runfiles=r)
else:
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir, runfiles=r)
copy_runfiles(
dst_dir=jaxlib_dir, src_files=[f"__main__/jaxlib/cuda_plugin_extension.{pyext}"]
)
copy_runfiles(
dst_dir=jaxlib_dir,
src_files=[
f"__main__/jaxlib/cpu_feature_guard.{pyext}",
f"__main__/jaxlib/utils.{pyext}",
"__main__/jaxlib/lapack.py",
"__main__/jaxlib/hlo_helpers.py",
"__main__/jaxlib/ducc_fft.py",
"__main__/jaxlib/gpu_prng.py",
"__main__/jaxlib/gpu_linalg.py",
"__main__/jaxlib/gpu_rnn.py",
"__main__/jaxlib/gpu_triton.py",
"__main__/jaxlib/gpu_common_utils.py",
"__main__/jaxlib/gpu_solver.py",
"__main__/jaxlib/gpu_sparse.py",
"__main__/jaxlib/tpu_mosaic.py",
"__main__/jaxlib/version.py",
"__main__/jaxlib/xla_client.py",
f"__main__/jaxlib/xla_extension.{pyext}",
],
)
# This file is required by PEP-561. It marks jaxlib as package containing
# type stubs.
with open(jaxlib_dir / "py.typed", "w"):
pass
patch_copy_xla_extension_stubs(jaxlib_dir)
copy_runfiles(
dst_dir=jaxlib_dir / "cpu",
src_files=[
f"__main__/jaxlib/cpu/_lapack.{pyext}",
f"__main__/jaxlib/cpu/_ducc_fft.{pyext}",
],
)
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
copy_runfiles(
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
)
copy_runfiles(
dst_dir=jaxlib_dir / "cuda",
src_files=[
f"__main__/jaxlib/cuda/_solver.{pyext}",
f"__main__/jaxlib/cuda/_blas.{pyext}",
f"__main__/jaxlib/cuda/_linalg.{pyext}",
f"__main__/jaxlib/cuda/_prng.{pyext}",
f"__main__/jaxlib/cuda/_rnn.{pyext}",
f"__main__/jaxlib/cuda/_sparse.{pyext}",
f"__main__/jaxlib/cuda/_triton.{pyext}",
f"__main__/jaxlib/cuda/_versions.{pyext}",
],
)
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
copy_runfiles(
dst_dir=jaxlib_dir / "rocm",
src_files=[
f"__main__/jaxlib/rocm/_solver.{pyext}",
f"__main__/jaxlib/rocm/_blas.{pyext}",
f"__main__/jaxlib/rocm/_linalg.{pyext}",
f"__main__/jaxlib/rocm/_prng.{pyext}",
f"__main__/jaxlib/rocm/_sparse.{pyext}",
],
)
mosaic_python_dir = jaxlib_dir / "mosaic" / "python"
copy_runfiles(
dst_dir=mosaic_python_dir,
src_files=[
"__main__/jaxlib/mosaic/python/apply_vector_layout.py",
"__main__/jaxlib/mosaic/python/infer_memref_layout.py",
"__main__/jaxlib/mosaic/python/tpu.py",
"__main__/jaxlib/mosaic/python/_tpu_ops_ext.py",
],
)
# TODO (sharadmv,skyewm): can we avoid patching this file?
patch_copy_mlir_import(
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir",
src_files=[
"__main__/jaxlib/mlir/ir.py",
"__main__/jaxlib/mlir/passmanager.py",
],
)
copy_runfiles(
dst_dir=jaxlib_dir / "mlir" / "dialects",
src_files=[
"__main__/jaxlib/mlir/dialects/_arith_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_arith_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_arith_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_func_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_func_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_math_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_memref_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_memref_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_ods_common.py",
"__main__/jaxlib/mlir/dialects/_scf_ops_ext.py",
"__main__/jaxlib/mlir/dialects/_scf_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py",
"__main__/jaxlib/mlir/dialects/_vector_enum_gen.py",
"__main__/jaxlib/mlir/dialects/_vector_ops_gen.py",
"__main__/jaxlib/mlir/dialects/arith.py",
"__main__/jaxlib/mlir/dialects/builtin.py",
"__main__/jaxlib/mlir/dialects/chlo.py",
"__main__/jaxlib/mlir/dialects/func.py",
"__main__/jaxlib/mlir/dialects/math.py",
"__main__/jaxlib/mlir/dialects/memref.py",
"__main__/jaxlib/mlir/dialects/mhlo.py",
"__main__/jaxlib/mlir/dialects/ml_program.py",
"__main__/jaxlib/mlir/dialects/scf.py",
"__main__/jaxlib/mlir/dialects/sparse_tensor.py",
"__main__/jaxlib/mlir/dialects/stablehlo.py",
"__main__/jaxlib/mlir/dialects/vector.py",
],
)
if build_utils.is_windows():
capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll"
else:
so_ext = "dylib" if _is_mac() else "so"
capi_so = f"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.{so_ext}"
mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs"
copy_runfiles(
dst_dir=mlir_libs_dir,
src_files=[
capi_so,
"__main__/jaxlib/mlir/_mlir_libs/__init__.py",
f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}",
],
)
tmpdir = None
sources_path = args.sources_path
@ -294,7 +341,7 @@ if sources_path is None:
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_wheel(
sources_path,
pathlib.Path(sources_path),
cpu=args.cpu,
include_gpu_plugin_extension=args.include_gpu_plugin_extension,
)