build/build.py changes: copy the wheels created by the new build wheel targets into the path specified by --output_path.

PiperOrigin-RevId: 740829299
This commit is contained in:
jax authors 2025-03-26 10:55:14 -07:00
parent b1b281a427
commit 55318d5824
3 changed files with 59 additions and 9 deletions

View File

@ -76,6 +76,8 @@ WHEEL_BUILD_TARGET_DICT_NEW = {
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
}
_JAX_CUDA_VERSION = "12"
def add_global_arguments(parser: argparse.ArgumentParser):
"""Adds all the global arguments that applies to all the CLI subcommands."""
parser.add_argument(
@ -695,6 +697,35 @@ async def main():
if result.return_code != 0:
raise RuntimeError(f"Command failed with return code {result.return_code}")
if args.use_new_wheel_build_rule:
output_path = args.output_path
jax_bazel_dir = os.path.join("bazel-bin", "dist")
jaxlib_and_plugins_bazel_dir = os.path.join(
"bazel-bin", "jaxlib", "tools", "dist"
)
for wheel in args.wheels.split(","):
if wheel == "jax":
bazel_dir = jax_bazel_dir
else:
bazel_dir = jaxlib_and_plugins_bazel_dir
if "cuda" in wheel:
wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace(
"-", "_"
)
else:
wheel_dir = wheel
if args.editable:
src_dir = os.path.join(bazel_dir, wheel_dir)
dst_dir = os.path.join(output_path, wheel_dir)
utils.copy_dir_recursively(src_dir, dst_dir)
else:
utils.copy_individual_files(bazel_dir, output_path, f"{wheel_dir}*.whl")
if wheel == "jax":
utils.copy_individual_files(
bazel_dir, output_path, f"{wheel_dir}*.tar.gz"
)
# Exit with success if all wheels in the list were built successfully.
sys.exit(0)

View File

@ -14,6 +14,7 @@
# ==============================================================================
# Helper script for tools/utilities used by the JAX build CLI.
import collections
import glob
import hashlib
import logging
import os
@ -256,3 +257,28 @@ def _parse_string_as_bool(s):
return False
else:
raise ValueError(f"Expected either 'true' or 'false'; got {s}")
def copy_dir_recursively(src, dst):
if os.path.exists(dst):
shutil.rmtree(dst)
os.makedirs(dst, exist_ok=True)
for root, dirs, files in os.walk(src):
relative_path = os.path.relpath(root, src)
dst_dir = os.path.join(dst, relative_path)
os.makedirs(dst_dir, exist_ok=True)
for f in files:
src_file = os.path.join(root, f)
dst_file = os.path.join(dst_dir, f)
shutil.copy2(src_file, dst_file)
logging.info("Editable wheel path: %s" % dst)
def copy_individual_files(src, dst, regex):
os.makedirs(dst, exist_ok=True)
for f in glob.glob(os.path.join(src, regex)):
dst_file = os.path.join(dst, os.path.basename(f))
if os.path.exists(dst_file):
os.remove(dst_file)
shutil.copy2(f, dst_file)
logging.info("Distribution path: %s" % dst_file)

View File

@ -96,6 +96,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
--output_path="$JAXCI_OUTPUT_DIR" \
$artifact_tag_flags
# If building release artifacts, we also build a release candidate ("rc")
@ -105,18 +106,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
--output_path="$JAXCI_OUTPUT_DIR" \
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
fi
# Move the built artifacts from the Bazel cache directory to the output
# directory.
if [[ "$artifact" == "jax" ]]; then
mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR"
mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR"
else
mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR"
fi
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then