Add functionality to allow promoting RC wheels during release

List of changes:
1. Allow us to build a RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go way in the future when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both rc and release tagged wheels (changes internal)

PiperOrigin-RevId: 733464219
This commit is contained in:
Nitin Srinivasan 2025-03-04 14:20:27 -08:00 committed by jax authors
parent 43b6be0e81
commit 721d1a3211
5 changed files with 135 additions and 58 deletions

View File

@ -63,6 +63,17 @@ WHEEL_BUILD_TARGET_DICT = {
"jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
}
# Dictionary with the new wheel build rule. Note that when JAX migrates to the
# new wheel build rule fully, the build CLI will switch to the new wheel build
# rule as the default.
WHEEL_BUILD_TARGET_DICT_NEW = {
"jax": "//:jax_wheel",
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
"jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel",
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
}
def add_global_arguments(parser: argparse.ArgumentParser):
"""Adds all the global arguments that applies to all the CLI subcommands."""
@ -147,6 +158,16 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--use_new_wheel_build_rule",
action="store_true",
help=
"""
Whether to use the new wheel build rule. Temporary flag and will be
removed once JAX migrates to the new wheel build rule fully.
""",
)
parser.add_argument(
"--editable",
action="store_true",
@ -386,7 +407,10 @@ async def main():
for option in args.bazel_startup_options:
bazel_command_base.append(option)
bazel_command_base.append("run")
if not args.use_new_wheel_build_rule or args.command == "requirements_update":
bazel_command_base.append("run")
else:
bazel_command_base.append("build")
if args.python_version:
# Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
@ -592,13 +616,19 @@ async def main():
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list())
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
if not jax_configure_options:
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
sys.exit(1)
f.write(jax_configure_options)
logging.info("Bazel options written to .jax_configure.bazelrc")
if args.use_new_wheel_build_rule:
logging.info("Using new wheel build rule")
wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
else:
wheel_build_targets = WHEEL_BUILD_TARGET_DICT
if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
@ -611,7 +641,7 @@ async def main():
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
wheel = "jax-" + wheel
if wheel not in WHEEL_BUILD_TARGET_DICT.keys():
if wheel not in wheel_build_targets.keys():
logging.error(
"Incorrect wheel name provided, valid choices are jaxlib,"
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
@ -629,32 +659,33 @@ async def main():
)
# Append the build target to the Bazel command.
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
build_target = wheel_build_targets[wheel]
wheel_build_command.append(build_target)
wheel_build_command.append("--")
if not args.use_new_wheel_build_rule:
wheel_build_command.append("--")
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")
wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")
wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")
if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")
if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")
if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")
if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
# Exit with error if any wheel build fails.

View File

@ -213,11 +213,15 @@ def get_gcc_major_version(gcc_path: str):
return major_version
def get_jax_configure_bazel_options(bazel_command: list[str]):
def get_jax_configure_bazel_options(bazel_command: list[str], use_new_wheel_build_rule: bool):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so
# we find the index of "run" and filter everything after it.
start = bazel_command.index("run")
# we find the index of "run" and filter everything after it. If we are using
# the new wheel build rule, we will find the index of "build" instead.
if use_new_wheel_build_rule:
start = bazel_command.index("build")
else:
start = bazel_command.index("run")
jax_configure_bazel_options = ""
try:
for i in range(start + 1, len(bazel_command)):

View File

@ -45,52 +45,82 @@ if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then
arch="amd64"
fi
# Determine the artifact tag flags based on the artifact type. A release
# wheel is tagged with the release version (e.g. 0.5.1), a nightly wheel is
# tagged with the release version and a nightly suffix that contains the
# current date (e.g. 0.5.2.dev20250227), and a default wheel is tagged with
# the git commit hash of the HEAD of the current branch and the date of the
# commit (e.g. 0.5.1.dev20250128+3e75e20c7).
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=release"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "nightly" ]]; then
current_date=$(date +%Y%m%d)
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=${current_date} --bazel_options=--repo_env=ML_WHEEL_TYPE=nightly"
elif [[ "$JAXCI_ARTIFACT_TYPE" == "default" ]]; then
artifact_tag_flags="--bazel_options=--repo_env=ML_WHEEL_TYPE=custom --bazel_options=--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD) --bazel_options=--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)"
else
echo "Error: Invalid artifact type: $JAXCI_ARTIFACT_TYPE. Allowed values are: release, nightly, default"
exit 1
fi
if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"
# Build the jax artifact
if [[ "$artifact" == "jax" ]]; then
python -m build --outdir $JAXCI_OUTPUT_DIR
# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
else
bazelrc_config="ci_${bazelrc_config}"
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"
# On platforms with no RBE support, we can use the Bazel remote cache. Set
# it to be empty by default to avoid unbound variable errors.
bazel_remote_cache=""
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
bazelrc_config="rbe_${bazelrc_config}"
# Set remote cache flags. Pushes to the cache bucket is limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazelrc_config="ci_${bazelrc_config}"
# Set remote cache flags. Pushes to the cache bucket is limited to JAX's
# CI system.
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
bazel_remote_cache="--bazel_options=--config=public_cache_push"
else
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
bazel_remote_cache="--bazel_options=--config=public_cache"
fi
fi
# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi
# Use the "_cuda" configs when building the CUDA artifacts.
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
bazelrc_config="${bazelrc_config}_cuda"
fi
# Build the artifact.
# Build the artifact.
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags
# If building release artifacts, we also build a release candidate ("rc")
# tagged wheel.
if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log
--verbose --detailed_timestamped_log --use_new_wheel_build_rule \
$artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION"
fi
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]]; then
./ci/utilities/run_auditwheel.sh
fi
# Move the built artifacts from the Bazel cache directory to the output
# directory.
if [[ "$artifact" == "jax" ]]; then
mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR"
mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR"
else
mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR"
fi
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.
if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then
./ci/utilities/run_auditwheel.sh
fi
else

View File

@ -50,6 +50,15 @@ export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
# flag is enabled only for CI builds.
export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}
# Type of artifacts to build. Valid values are "default", "release", "nightly".
# This affects the wheel naming/tag.
export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"}
# When building release artifacts, we build a release candidate wheel ("rc"
# tagged wheel) in addition to the release wheel. This environment variable
# sets the version of the release candidate ("RC") artifact to build.
export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-}
# #############################################################################
# Test script specific environment variables.
# #############################################################################

View File

@ -98,3 +98,6 @@ function retry {
# Retry "bazel --version" 3 times to avoid flakiness when downloading bazel.
retry "bazel --version"
# Create the output directory if it doesn't exist.
mkdir -p "$JAXCI_OUTPUT_DIR"