mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
build/build.py
changes: copy the wheels created by the new build wheel targets into the path specified by --output_path
.
PiperOrigin-RevId: 740829299
This commit is contained in:
parent
b1b281a427
commit
55318d5824
@ -76,6 +76,8 @@ WHEEL_BUILD_TARGET_DICT_NEW = {
|
||||
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
|
||||
}
|
||||
|
||||
_JAX_CUDA_VERSION = "12"
|
||||
|
||||
def add_global_arguments(parser: argparse.ArgumentParser):
|
||||
"""Adds all the global arguments that applies to all the CLI subcommands."""
|
||||
parser.add_argument(
|
||||
@ -695,6 +697,35 @@ async def main():
|
||||
if result.return_code != 0:
|
||||
raise RuntimeError(f"Command failed with return code {result.return_code}")
|
||||
|
||||
if args.use_new_wheel_build_rule:
|
||||
output_path = args.output_path
|
||||
jax_bazel_dir = os.path.join("bazel-bin", "dist")
|
||||
jaxlib_and_plugins_bazel_dir = os.path.join(
|
||||
"bazel-bin", "jaxlib", "tools", "dist"
|
||||
)
|
||||
for wheel in args.wheels.split(","):
|
||||
if wheel == "jax":
|
||||
bazel_dir = jax_bazel_dir
|
||||
else:
|
||||
bazel_dir = jaxlib_and_plugins_bazel_dir
|
||||
if "cuda" in wheel:
|
||||
wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace(
|
||||
"-", "_"
|
||||
)
|
||||
else:
|
||||
wheel_dir = wheel
|
||||
|
||||
if args.editable:
|
||||
src_dir = os.path.join(bazel_dir, wheel_dir)
|
||||
dst_dir = os.path.join(output_path, wheel_dir)
|
||||
utils.copy_dir_recursively(src_dir, dst_dir)
|
||||
else:
|
||||
utils.copy_individual_files(bazel_dir, output_path, f"{wheel_dir}*.whl")
|
||||
if wheel == "jax":
|
||||
utils.copy_individual_files(
|
||||
bazel_dir, output_path, f"{wheel_dir}*.tar.gz"
|
||||
)
|
||||
|
||||
# Exit with success if all wheels in the list were built successfully.
|
||||
sys.exit(0)
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
# ==============================================================================
|
||||
# Helper script for tools/utilities used by the JAX build CLI.
|
||||
import collections
|
||||
import glob
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
@ -256,3 +257,28 @@ def _parse_string_as_bool(s):
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Expected either 'true' or 'false'; got {s}")
|
||||
|
||||
|
||||
def copy_dir_recursively(src, dst):
|
||||
if os.path.exists(dst):
|
||||
shutil.rmtree(dst)
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
for root, dirs, files in os.walk(src):
|
||||
relative_path = os.path.relpath(root, src)
|
||||
dst_dir = os.path.join(dst, relative_path)
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
for f in files:
|
||||
src_file = os.path.join(root, f)
|
||||
dst_file = os.path.join(dst_dir, f)
|
||||
shutil.copy2(src_file, dst_file)
|
||||
logging.info("Editable wheel path: %s" % dst)
|
||||
|
||||
|
||||
def copy_individual_files(src, dst, regex):
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
for f in glob.glob(os.path.join(src, regex)):
|
||||
dst_file = os.path.join(dst, os.path.basename(f))
|
||||
if os.path.exists(dst_file):
|
||||
os.remove(dst_file)
|
||||
shutil.copy2(f, dst_file)
|
||||
logging.info("Distribution path: %s" % dst_file)
|
||||
|
@ -96,6 +96,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
|
||||
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
|
||||
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
|
||||
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
|
||||
--output_path="$JAXCI_OUTPUT_DIR" \
|
||||
$artifact_tag_flags
|
||||
|
||||
# If building release artifacts, we also build a release candidate ("rc")
|
||||
@ -105,18 +106,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
|
||||
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
|
||||
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
|
||||
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
|
||||
--output_path="$JAXCI_OUTPUT_DIR" \
|
||||
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
|
||||
fi
|
||||
|
||||
# Move the built artifacts from the Bazel cache directory to the output
|
||||
# directory.
|
||||
if [[ "$artifact" == "jax" ]]; then
|
||||
mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR"
|
||||
mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR"
|
||||
else
|
||||
mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR"
|
||||
fi
|
||||
|
||||
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
|
||||
# run `auditwheel show` to verify manylinux compliance.
|
||||
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
|
||||
|
Loading…
x
Reference in New Issue
Block a user