mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add functionality to allow promoting RC wheels during release
List of changes: 1. Allow us to build an RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone who uses the CLI and the old build rule. Note that this option will go away in the future when the build CLI migrates fully to the new build rule. 2. Change the upload script to upload both RC and release tagged wheels (internal changes) PiperOrigin-RevId: 733464219
This commit is contained in:
parent
43b6be0e81
commit
721d1a3211
@ -63,6 +63,17 @@ WHEEL_BUILD_TARGET_DICT = {
|
||||
"jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
|
||||
}
|
||||
|
||||
# Dictionary with the new wheel build rule. Note that when JAX migrates to the
|
||||
# new wheel build rule fully, the build CLI will switch to the new wheel build
|
||||
# rule as the default.
|
||||
WHEEL_BUILD_TARGET_DICT_NEW = {
|
||||
"jax": "//:jax_wheel",
|
||||
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
|
||||
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
|
||||
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
|
||||
"jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel",
|
||||
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
|
||||
}
|
||||
|
||||
def add_global_arguments(parser: argparse.ArgumentParser):
|
||||
"""Adds all the global arguments that applies to all the CLI subcommands."""
|
||||
@ -147,6 +158,16 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_new_wheel_build_rule",
|
||||
action="store_true",
|
||||
help=
|
||||
"""
|
||||
Whether to use the new wheel build rule. Temporary flag and will be
|
||||
removed once JAX migrates to the new wheel build rule fully.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--editable",
|
||||
action="store_true",
|
||||
@ -386,7 +407,10 @@ async def main():
|
||||
for option in args.bazel_startup_options:
|
||||
bazel_command_base.append(option)
|
||||
|
||||
bazel_command_base.append("run")
|
||||
if not args.use_new_wheel_build_rule or args.command == "requirements_update":
|
||||
bazel_command_base.append("run")
|
||||
else:
|
||||
bazel_command_base.append("build")
|
||||
|
||||
if args.python_version:
|
||||
# Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
|
||||
@ -592,13 +616,19 @@ async def main():
|
||||
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
|
||||
|
||||
with open(".jax_configure.bazelrc", "w") as f:
|
||||
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list())
|
||||
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
|
||||
if not jax_configure_options:
|
||||
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
|
||||
sys.exit(1)
|
||||
f.write(jax_configure_options)
|
||||
logging.info("Bazel options written to .jax_configure.bazelrc")
|
||||
|
||||
if args.use_new_wheel_build_rule:
|
||||
logging.info("Using new wheel build rule")
|
||||
wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
|
||||
else:
|
||||
wheel_build_targets = WHEEL_BUILD_TARGET_DICT
|
||||
|
||||
if args.configure_only:
|
||||
logging.info("--configure_only is set so not running any Bazel commands.")
|
||||
else:
|
||||
@ -611,7 +641,7 @@ async def main():
|
||||
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
|
||||
wheel = "jax-" + wheel
|
||||
|
||||
if wheel not in WHEEL_BUILD_TARGET_DICT.keys():
|
||||
if wheel not in wheel_build_targets.keys():
|
||||
logging.error(
|
||||
"Incorrect wheel name provided, valid choices are jaxlib,"
|
||||
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
|
||||
@ -629,32 +659,33 @@ async def main():
|
||||
)
|
||||
|
||||
# Append the build target to the Bazel command.
|
||||
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
|
||||
build_target = wheel_build_targets[wheel]
|
||||
wheel_build_command.append(build_target)
|
||||
|
||||
wheel_build_command.append("--")
|
||||
if not args.use_new_wheel_build_rule:
|
||||
wheel_build_command.append("--")
|
||||
|
||||
if args.editable:
|
||||
logger.info("Building an editable build")
|
||||
output_path = os.path.join(output_path, wheel)
|
||||
wheel_build_command.append("--editable")
|
||||
if args.editable:
|
||||
logger.info("Building an editable build")
|
||||
output_path = os.path.join(output_path, wheel)
|
||||
wheel_build_command.append("--editable")
|
||||
|
||||
wheel_build_command.append(f'--output_path="{output_path}"')
|
||||
wheel_build_command.append(f"--cpu={target_cpu}")
|
||||
wheel_build_command.append(f'--output_path="{output_path}"')
|
||||
wheel_build_command.append(f"--cpu={target_cpu}")
|
||||
|
||||
if "cuda" in wheel:
|
||||
wheel_build_command.append("--enable-cuda=True")
|
||||
if args.cuda_version:
|
||||
cuda_major_version = args.cuda_version.split(".")[0]
|
||||
else:
|
||||
cuda_major_version = args.cuda_major_version
|
||||
wheel_build_command.append(f"--platform_version={cuda_major_version}")
|
||||
if "cuda" in wheel:
|
||||
wheel_build_command.append("--enable-cuda=True")
|
||||
if args.cuda_version:
|
||||
cuda_major_version = args.cuda_version.split(".")[0]
|
||||
else:
|
||||
cuda_major_version = args.cuda_major_version
|
||||
wheel_build_command.append(f"--platform_version={cuda_major_version}")
|
||||
|
||||
if "rocm" in wheel:
|
||||
wheel_build_command.append("--enable-rocm=True")
|
||||
wheel_build_command.append(f"--platform_version={args.rocm_version}")
|
||||
if "rocm" in wheel:
|
||||
wheel_build_command.append("--enable-rocm=True")
|
||||
wheel_build_command.append(f"--platform_version={args.rocm_version}")
|
||||
|
||||
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
|
||||
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
|
||||
|
||||
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
|
||||
# Exit with error if any wheel build fails.
|
||||
|
@ -213,11 +213,15 @@ def get_gcc_major_version(gcc_path: str):
|
||||
return major_version
|
||||
|
||||
|
||||
def get_jax_configure_bazel_options(bazel_command: list[str]):
|
||||
def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool):
|
||||
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
|
||||
# Get the index of the "run" parameter. Build options will come after "run" so
|
||||
# we find the index of "run" and filter everything after it.
|
||||
start = bazel_command.index("run")
|
||||
# we find the index of "run" and filter everything after it. If we are using
|
||||
# the new wheel build rule, we will find the index of "build" instead.
|
||||
if use_new_wheel_build_rule:
|
||||
start = bazel_command.index("build")
|
||||
else:
|
||||
start = bazel_command.index("run")
|
||||
jax_configure_bazel_options = ""
|
||||
try:
|
||||
for i in range(start + 1, len(bazel_command)):
|
||||
|
@ -45,52 +45,82 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then
|
||||
arch="amd64"
|
||||
fi
|
||||
|
||||
# Determine the artifact tag flags based on the artifact type. A release
|
||||
# wheel is tagged with the release version (e.g. 0.5.1), a nightly wheel is
|
||||
# tagged with the release version and a nightly suffix that contains the
|
||||
# current date (e.g. 0.5.2.dev20250227), and a default wheel is tagged with
|
||||
# the git commit hash of the HEAD of the current branch and the date of the
|
||||
# commit (e.g. 0.5.1.dev20250128+3e75e20c7).
|
||||
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
|
||||
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release"
|
||||
elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then
|
||||
current_date=$(date +%Y%m%d)
|
||||
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly"
|
||||
elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then
|
||||
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
|
||||
else
|
||||
echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
|
||||
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
|
||||
# flags in the .bazelrc depending upon the platform we are building for.
|
||||
bazelrc_config="${os}_${arch}"
|
||||
|
||||
# Build the jax artifact
|
||||
if [[ "$artifact" == "jax" ]]; then
|
||||
python -m build --outdir $JAXCI_OUTPUT_DIR
|
||||
# On platforms with no RBE support, we can use the Bazel remote cache. Set
|
||||
# it to be empty by default to avoid unbound variable errors.
|
||||
bazel_remote_cache=""
|
||||
|
||||
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
|
||||
bazelrc_config="rbe_${bazelrc_config}"
|
||||
else
|
||||
bazelrc_config="ci_${bazelrc_config}"
|
||||
|
||||
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
|
||||
# flags in the .bazelrc depending upon the platform we are building for.
|
||||
bazelrc_config="${os}_${arch}"
|
||||
|
||||
# On platforms with no RBE support, we can use the Bazel remote cache. Set
|
||||
# it to be empty by default to avoid unbound variable errors.
|
||||
bazel_remote_cache=""
|
||||
|
||||
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
|
||||
bazelrc_config="rbe_${bazelrc_config}"
|
||||
# Set remote cache flags. Pushes to the cache bucket is limited to JAX's
|
||||
# CI system.
|
||||
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
|
||||
bazel_remote_cache="--bazel_options=--config=public_cache_push"
|
||||
else
|
||||
bazelrc_config="ci_${bazelrc_config}"
|
||||
|
||||
# Set remote cache flags. Pushes to the cache bucket is limited to JAX's
|
||||
# CI system.
|
||||
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
|
||||
bazel_remote_cache="--bazel_options=--config=public_cache_push"
|
||||
else
|
||||
bazel_remote_cache="--bazel_options=--config=public_cache"
|
||||
fi
|
||||
bazel_remote_cache="--bazel_options=--config=public_cache"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Use the "_cuda" configs when building the CUDA artifacts.
|
||||
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
|
||||
bazelrc_config="${bazelrc_config}_cuda"
|
||||
fi
|
||||
# Use the "_cuda" configs when building the CUDA artifacts.
|
||||
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
|
||||
bazelrc_config="${bazelrc_config}_cuda"
|
||||
fi
|
||||
|
||||
# Build the artifact.
|
||||
# Build the artifact.
|
||||
python build/build.py build --wheels="$artifact" \
|
||||
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
|
||||
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
|
||||
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
|
||||
$artifact_tag_flags
|
||||
|
||||
# If building release artifacts, we also build a release candidate ("rc")
|
||||
# tagged wheel.
|
||||
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
|
||||
python build/build.py build --wheels="$artifact" \
|
||||
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
|
||||
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
|
||||
--verbose --detailed_timestamped_log
|
||||
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
|
||||
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
|
||||
fi
|
||||
|
||||
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
|
||||
# run `auditwheel show` to verify manylinux compliance.
|
||||
if [[ "$os" == "linux" ]]; then
|
||||
./ci/utilities/run_auditwheel.sh
|
||||
fi
|
||||
# Move the built artifacts from the Bazel cache directory to the output
|
||||
# directory.
|
||||
if [[ "$artifact" == "jax" ]]; then
|
||||
mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR"
|
||||
mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR"
|
||||
else
|
||||
mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR"
|
||||
fi
|
||||
|
||||
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
|
||||
# run `auditwheel show` to verify manylinux compliance.
|
||||
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
|
||||
./ci/utilities/run_auditwheel.sh
|
||||
fi
|
||||
|
||||
else
|
||||
|
@ -50,6 +50,15 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
|
||||
# flag is enabled only for CI builds.
|
||||
export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}
|
||||
|
||||
# Type of artifacts to build. Valid values are "default", "release", "nightly".
|
||||
# This affects the wheel naming/tag.
|
||||
export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"}
|
||||
|
||||
# When building release artifacts, we build a release candidate wheel ("rc"
|
||||
# tagged wheel) in addition to the release wheel. This environment variable
|
||||
# sets the version of the release candidate ("RC") artifact to build.
|
||||
export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-}
|
||||
|
||||
# #############################################################################
|
||||
# Test script specific environment variables.
|
||||
# #############################################################################
|
||||
|
@ -98,3 +98,6 @@ function retry {
|
||||
|
||||
# Retry "bazel --version" 3 times to avoid flakiness when downloading bazel.
|
||||
retry "bazel --version"
|
||||
|
||||
# Create the output directory if it doesn't exist.
|
||||
mkdir -p "$JAXCI_OUTPUT_DIR"
|
Loading…
x
Reference in New Issue
Block a user