mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #17864 from hawkinsp:buildwheel
PiperOrigin-RevId: 569576169
This commit is contained in:
commit
095c367c01
@ -14,12 +14,16 @@
|
||||
|
||||
"""Utilities for the building JAX related python packages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import subprocess
|
||||
import glob
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
@ -27,20 +31,22 @@ def is_windows() -> bool:
|
||||
|
||||
|
||||
def copy_file(
|
||||
src_file: str,
|
||||
dst_dir: str,
|
||||
dst_filename=None,
|
||||
from_runfiles=True,
|
||||
runfiles=None,
|
||||
src_files: str | Sequence[str],
|
||||
dst_dir: pathlib.Path,
|
||||
dst_filename = None,
|
||||
runfiles = None,
|
||||
) -> None:
|
||||
if from_runfiles:
|
||||
dst_dir.mkdir(parents=True, exist_ok=True)
|
||||
if isinstance(src_files, str):
|
||||
src_files = [src_files]
|
||||
for src_file in src_files:
|
||||
src_file = runfiles.Rlocation(src_file)
|
||||
src_filename = os.path.basename(src_file)
|
||||
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
||||
if is_windows():
|
||||
shutil.copyfile(src_file, dst_file)
|
||||
else:
|
||||
shutil.copy(src_file, dst_file)
|
||||
src_filename = os.path.basename(src_file)
|
||||
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
||||
if is_windows():
|
||||
shutil.copyfile(src_file, dst_file)
|
||||
else:
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def platform_tag(cpu: str) -> str:
|
||||
|
@ -18,7 +18,9 @@
|
||||
# Most users should not run this script directly; use build.py instead.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
@ -26,20 +28,20 @@ from jax.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--sources_path",
|
||||
default=None,
|
||||
help="Path in which the wheel's sources should be prepared. Optional. If "
|
||||
"omitted, a temporary directory will be used.")
|
||||
"--sources_path",
|
||||
default=None,
|
||||
help="Path in which the wheel's sources should be prepared. Optional. If "
|
||||
"omitted, a temporary directory will be used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.")
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Target CPU architecture. Required.")
|
||||
"--cpu", default=None, required=True, help="Target CPU architecture. Required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cuda_version",
|
||||
default=None,
|
||||
@ -58,51 +60,53 @@ r = runfiles.Create()
|
||||
|
||||
def write_setup_cfg(sources_path, cpu):
|
||||
tag = build_utils.platform_tag(cpu)
|
||||
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
|
||||
f.write(f"""[metadata]
|
||||
with open(sources_path / "setup.cfg", "w") as f:
|
||||
f.write(
|
||||
f"""[metadata]
|
||||
license_files = LICENSE.txt
|
||||
|
||||
[bdist_wheel]
|
||||
plat-name={tag}
|
||||
python-tag=py3
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def update_setup(file_dir, cuda_version):
|
||||
src_file = os.path.join(file_dir, "setup.py")
|
||||
with open(os.path.join(src_file), "r") as f:
|
||||
src_file = file_dir / "setup.py"
|
||||
with open(src_file, "r") as f:
|
||||
content = f.read()
|
||||
content = content.replace(
|
||||
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
|
||||
)
|
||||
with open(os.path.join(src_file), "w") as f:
|
||||
with open(src_file, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def prepare_cuda_plugin_wheel(sources_path, *, cpu, cuda_version):
|
||||
def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
|
||||
"""Assembles a source tree for the wheel in `sources_path`."""
|
||||
jax_plugins_dir = os.path.join(sources_path, "jax_plugins")
|
||||
os.makedirs(jax_plugins_dir)
|
||||
plugin_dir = os.path.join(jax_plugins_dir, f"xla_cuda_cu{cuda_version}")
|
||||
os.makedirs(plugin_dir)
|
||||
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
|
||||
|
||||
build_utils.copy_file(
|
||||
"__main__/plugins/cuda/pyproject.toml", dst_dir=sources_path, runfiles=r
|
||||
)
|
||||
build_utils.copy_file(
|
||||
"__main__/plugins/cuda/setup.py", dst_dir=sources_path, runfiles=r
|
||||
plugin_dir = sources_path / "jax_plugins" / f"xla_cuda_cu{cuda_version}"
|
||||
copy_runfiles(
|
||||
dst_dir=sources_path,
|
||||
src_files=[
|
||||
"__main__/plugins/cuda/pyproject.toml",
|
||||
"__main__/plugins/cuda/setup.py",
|
||||
],
|
||||
)
|
||||
update_setup(sources_path, cuda_version)
|
||||
write_setup_cfg(sources_path, cpu)
|
||||
build_utils.copy_file(
|
||||
"__main__/plugins/cuda/__init__.py", dst_dir=plugin_dir, runfiles=r
|
||||
copy_runfiles(
|
||||
dst_dir=plugin_dir,
|
||||
src_files=[
|
||||
"__main__/plugins/cuda/__init__.py",
|
||||
],
|
||||
)
|
||||
plugin_so_path = r.Rlocation("xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so")
|
||||
build_utils.copy_file(
|
||||
plugin_so_path,
|
||||
copy_runfiles(
|
||||
"xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so",
|
||||
dst_dir=plugin_dir,
|
||||
dst_filename="xla_cuda_plugin.so",
|
||||
runfiles=r,
|
||||
)
|
||||
|
||||
|
||||
@ -115,7 +119,7 @@ if sources_path is None:
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_cuda_plugin_wheel(
|
||||
sources_path, cpu=args.cpu, cuda_version=args.cuda_version
|
||||
pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version
|
||||
)
|
||||
package_name = "jax cuda plugin"
|
||||
if args.editable:
|
||||
|
@ -20,6 +20,7 @@
|
||||
import argparse
|
||||
import functools
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
@ -30,24 +31,25 @@ from jax.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--sources_path",
|
||||
default=None,
|
||||
help="Path in which the wheel's sources should be prepared. Optional. If "
|
||||
"omitted, a temporary directory will be used.")
|
||||
"--sources_path",
|
||||
default=None,
|
||||
help="Path in which the wheel's sources should be prepared. Optional. If "
|
||||
"omitted, a temporary directory will be used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.")
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Target CPU architecture. Required.")
|
||||
"--cpu", default=None, required=True, help="Target CPU architecture. Required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--editable",
|
||||
action="store_true",
|
||||
help="Create an 'editable' jaxlib build instead of a wheel.")
|
||||
"--editable",
|
||||
action="store_true",
|
||||
help="Create an 'editable' jaxlib build instead of a wheel.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include_gpu_plugin_extension",
|
||||
# args.include_gpu_plugin_extension is True when
|
||||
@ -77,10 +79,16 @@ def patch_copy_mlir_import(src_file, dst_dir):
|
||||
with open(src_file) as f:
|
||||
src = f.read()
|
||||
|
||||
with open(os.path.join(dst_dir, src_filename), 'w') as f:
|
||||
replaced = re.sub(r'^from mlir(\..*)? import (.*)', r'from jaxlib.mlir\1 import \2', src, flags=re.MULTILINE)
|
||||
with open(dst_dir / src_filename, "w") as f:
|
||||
replaced = re.sub(
|
||||
r"^from mlir(\..*)? import (.*)",
|
||||
r"from jaxlib.mlir\1 import \2",
|
||||
src,
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
f.write(replaced)
|
||||
|
||||
|
||||
_XLA_EXTENSION_STUBS = [
|
||||
"__init__.pyi",
|
||||
"jax_jit.pyi",
|
||||
@ -91,28 +99,21 @@ _XLA_EXTENSION_STUBS = [
|
||||
"pytree.pyi",
|
||||
"transfer_guard_lib.pyi",
|
||||
]
|
||||
_OPTIONAL_XLA_EXTENSION_STUBS = [
|
||||
]
|
||||
_OPTIONAL_XLA_EXTENSION_STUBS = []
|
||||
|
||||
|
||||
def patch_copy_xla_extension_stubs(dst_dir):
|
||||
# This file is required by PEP-561. It marks jaxlib as package containing
|
||||
# type stubs.
|
||||
with open(os.path.join(dst_dir, "py.typed"), "w"):
|
||||
pass
|
||||
xla_extension_dir = os.path.join(dst_dir, "xla_extension")
|
||||
os.makedirs(xla_extension_dir)
|
||||
for stub_name in _XLA_EXTENSION_STUBS:
|
||||
stub_path = r.Rlocation(
|
||||
"xla/xla/python/xla_extension/" + stub_name)
|
||||
stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name)
|
||||
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
|
||||
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
|
||||
continue
|
||||
with open(stub_path) as f:
|
||||
src = f.read()
|
||||
src = src.replace(
|
||||
"from xla.python import xla_extension",
|
||||
"from .. import xla_extension"
|
||||
"from xla.python import xla_extension", "from .. import xla_extension"
|
||||
)
|
||||
with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
|
||||
f.write(src)
|
||||
@ -130,160 +131,206 @@ def verify_mac_libraries_dont_reference_chkstack():
|
||||
if not _is_mac():
|
||||
return
|
||||
nm = subprocess.run(
|
||||
["nm", "-g",
|
||||
r.Rlocation("xla/xla/python/xla_extension.so")
|
||||
],
|
||||
capture_output=True, text=True,
|
||||
check=False)
|
||||
["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if nm.returncode != 0:
|
||||
raise RuntimeError(f"nm process failed: {nm.stdout} {nm.stderr}")
|
||||
if "____chkstk_darwin" in nm.stdout:
|
||||
raise RuntimeError(
|
||||
"Mac wheel incorrectly depends on symbol ____chkstk_darwin, which "
|
||||
"means that it isn't compatible with older MacOS versions.")
|
||||
"Mac wheel incorrectly depends on symbol ____chkstk_darwin, which "
|
||||
"means that it isn't compatible with older MacOS versions."
|
||||
)
|
||||
|
||||
|
||||
def write_setup_cfg(sources_path, cpu):
|
||||
tag = build_utils.platform_tag(cpu)
|
||||
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
|
||||
f.write(f"""[metadata]
|
||||
with open(sources_path / "setup.cfg", "w") as f:
|
||||
f.write(
|
||||
f"""[metadata]
|
||||
license_files = LICENSE.txt
|
||||
|
||||
[bdist_wheel]
|
||||
plat-name={tag}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def prepare_wheel(sources_path, *, cpu, include_gpu_plugin_extension):
|
||||
def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extension):
|
||||
"""Assembles a source tree for the wheel in `sources_path`."""
|
||||
jaxlib_dir = os.path.join(sources_path, "jaxlib")
|
||||
os.makedirs(jaxlib_dir)
|
||||
copy_to_jaxlib = functools.partial(build_utils.copy_file, dst_dir=jaxlib_dir, runfiles=r)
|
||||
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
|
||||
|
||||
verify_mac_libraries_dont_reference_chkstack()
|
||||
build_utils.copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/README.md", dst_dir=sources_path, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path, runfiles=r)
|
||||
copy_runfiles(
|
||||
dst_dir=sources_path,
|
||||
src_files=[
|
||||
"__main__/jaxlib/tools/LICENSE.txt",
|
||||
"__main__/jaxlib/README.md",
|
||||
"__main__/jaxlib/setup.py",
|
||||
],
|
||||
)
|
||||
write_setup_cfg(sources_path, cpu)
|
||||
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
|
||||
|
||||
jaxlib_dir = sources_path / "jaxlib"
|
||||
copy_runfiles(
|
||||
"__main__/jaxlib/init.py", dst_dir=jaxlib_dir, dst_filename="__init__.py"
|
||||
)
|
||||
if include_gpu_plugin_extension:
|
||||
copy_to_jaxlib(f"__main__/jaxlib/cuda_plugin_extension.{pyext}")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/utils.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/lapack.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/hlo_helpers.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_rnn.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_triton.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_common_utils.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/tpu_mosaic.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/version.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/xla_client.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}")
|
||||
cpu_dir = os.path.join(jaxlib_dir, "cpu")
|
||||
os.makedirs(cpu_dir)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir, runfiles=r)
|
||||
|
||||
cuda_dir = os.path.join(jaxlib_dir, "cuda")
|
||||
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
|
||||
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
|
||||
os.makedirs(libdevice_dir)
|
||||
build_utils.copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_versions.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
rocm_dir = os.path.join(jaxlib_dir, "rocm")
|
||||
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
|
||||
os.makedirs(rocm_dir)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
|
||||
mosaic_dir = os.path.join(jaxlib_dir, "mosaic")
|
||||
mosaic_python_dir = os.path.join(mosaic_dir, "python")
|
||||
os.makedirs(mosaic_dir)
|
||||
os.makedirs(mosaic_python_dir)
|
||||
copy_to_jaxlib("__main__/jaxlib/mosaic/python/apply_vector_layout.py", dst_dir=mosaic_python_dir)
|
||||
copy_to_jaxlib("__main__/jaxlib/mosaic/python/infer_memref_layout.py", dst_dir=mosaic_python_dir)
|
||||
copy_to_jaxlib("__main__/jaxlib/mosaic/python/tpu.py", dst_dir=mosaic_python_dir)
|
||||
build_utils.copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir, runfiles=r)
|
||||
# TODO (sharadmv,skyewm): can we avoid patching this file?
|
||||
patch_copy_mlir_import("__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir)
|
||||
|
||||
mlir_dir = os.path.join(jaxlib_dir, "mlir")
|
||||
mlir_dialects_dir = os.path.join(jaxlib_dir, "mlir", "dialects")
|
||||
mlir_libs_dir = os.path.join(jaxlib_dir, "mlir", "_mlir_libs")
|
||||
os.makedirs(mlir_dir)
|
||||
os.makedirs(mlir_dialects_dir)
|
||||
os.makedirs(mlir_libs_dir)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
if build_utils.is_windows():
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
elif _is_mac():
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
else:
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir, src_files=[f"__main__/jaxlib/cuda_plugin_extension.{pyext}"]
|
||||
)
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir,
|
||||
src_files=[
|
||||
f"__main__/jaxlib/cpu_feature_guard.{pyext}",
|
||||
f"__main__/jaxlib/utils.{pyext}",
|
||||
"__main__/jaxlib/lapack.py",
|
||||
"__main__/jaxlib/hlo_helpers.py",
|
||||
"__main__/jaxlib/ducc_fft.py",
|
||||
"__main__/jaxlib/gpu_prng.py",
|
||||
"__main__/jaxlib/gpu_linalg.py",
|
||||
"__main__/jaxlib/gpu_rnn.py",
|
||||
"__main__/jaxlib/gpu_triton.py",
|
||||
"__main__/jaxlib/gpu_common_utils.py",
|
||||
"__main__/jaxlib/gpu_solver.py",
|
||||
"__main__/jaxlib/gpu_sparse.py",
|
||||
"__main__/jaxlib/tpu_mosaic.py",
|
||||
"__main__/jaxlib/version.py",
|
||||
"__main__/jaxlib/xla_client.py",
|
||||
f"__main__/jaxlib/xla_extension.{pyext}",
|
||||
],
|
||||
)
|
||||
# This file is required by PEP-561. It marks jaxlib as package containing
|
||||
# type stubs.
|
||||
with open(jaxlib_dir / "py.typed", "w"):
|
||||
pass
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "cpu",
|
||||
src_files=[
|
||||
f"__main__/jaxlib/cpu/_lapack.{pyext}",
|
||||
f"__main__/jaxlib/cpu/_ducc_fft.{pyext}",
|
||||
],
|
||||
)
|
||||
|
||||
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
|
||||
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
|
||||
)
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "cuda",
|
||||
src_files=[
|
||||
f"__main__/jaxlib/cuda/_solver.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_blas.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_linalg.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_prng.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_rnn.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_sparse.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_triton.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_versions.{pyext}",
|
||||
],
|
||||
)
|
||||
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "rocm",
|
||||
src_files=[
|
||||
f"__main__/jaxlib/rocm/_solver.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_blas.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_linalg.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_prng.{pyext}",
|
||||
f"__main__/jaxlib/rocm/_sparse.{pyext}",
|
||||
],
|
||||
)
|
||||
|
||||
mosaic_python_dir = jaxlib_dir / "mosaic" / "python"
|
||||
copy_runfiles(
|
||||
dst_dir=mosaic_python_dir,
|
||||
src_files=[
|
||||
"__main__/jaxlib/mosaic/python/apply_vector_layout.py",
|
||||
"__main__/jaxlib/mosaic/python/infer_memref_layout.py",
|
||||
"__main__/jaxlib/mosaic/python/tpu.py",
|
||||
"__main__/jaxlib/mosaic/python/_tpu_ops_ext.py",
|
||||
],
|
||||
)
|
||||
# TODO (sharadmv,skyewm): can we avoid patching this file?
|
||||
patch_copy_mlir_import(
|
||||
"__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir
|
||||
)
|
||||
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "mlir",
|
||||
src_files=[
|
||||
"__main__/jaxlib/mlir/ir.py",
|
||||
"__main__/jaxlib/mlir/passmanager.py",
|
||||
],
|
||||
)
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "mlir" / "dialects",
|
||||
src_files=[
|
||||
"__main__/jaxlib/mlir/dialects/_arith_enum_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_arith_ops_ext.py",
|
||||
"__main__/jaxlib/mlir/dialects/_arith_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py",
|
||||
"__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_func_ops_ext.py",
|
||||
"__main__/jaxlib/mlir/dialects/_func_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_math_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_memref_ops_ext.py",
|
||||
"__main__/jaxlib/mlir/dialects/_memref_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py",
|
||||
"__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_ods_common.py",
|
||||
"__main__/jaxlib/mlir/dialects/_scf_ops_ext.py",
|
||||
"__main__/jaxlib/mlir/dialects/_scf_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_vector_enum_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/_vector_ops_gen.py",
|
||||
"__main__/jaxlib/mlir/dialects/arith.py",
|
||||
"__main__/jaxlib/mlir/dialects/builtin.py",
|
||||
"__main__/jaxlib/mlir/dialects/chlo.py",
|
||||
"__main__/jaxlib/mlir/dialects/func.py",
|
||||
"__main__/jaxlib/mlir/dialects/math.py",
|
||||
"__main__/jaxlib/mlir/dialects/memref.py",
|
||||
"__main__/jaxlib/mlir/dialects/mhlo.py",
|
||||
"__main__/jaxlib/mlir/dialects/ml_program.py",
|
||||
"__main__/jaxlib/mlir/dialects/scf.py",
|
||||
"__main__/jaxlib/mlir/dialects/sparse_tensor.py",
|
||||
"__main__/jaxlib/mlir/dialects/stablehlo.py",
|
||||
"__main__/jaxlib/mlir/dialects/vector.py",
|
||||
],
|
||||
)
|
||||
|
||||
if build_utils.is_windows():
|
||||
capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll"
|
||||
else:
|
||||
so_ext = "dylib" if _is_mac() else "so"
|
||||
capi_so = f"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.{so_ext}"
|
||||
|
||||
mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs"
|
||||
copy_runfiles(
|
||||
dst_dir=mlir_libs_dir,
|
||||
src_files=[
|
||||
capi_so,
|
||||
"__main__/jaxlib/mlir/_mlir_libs/__init__.py",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
@ -294,7 +341,7 @@ if sources_path is None:
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_wheel(
|
||||
sources_path,
|
||||
pathlib.Path(sources_path),
|
||||
cpu=args.cpu,
|
||||
include_gpu_plugin_extension=args.include_gpu_plugin_extension,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user