1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Add new Bazel remote cache configs

An example run where the cache configs are used: https://github.com/jax-ml/jax/actions/runs/12940123731

PiperOrigin-RevId: 719627011
This commit is contained in:
Nitin Srinivasan 2025-01-25 05:31:49 -08:00 committed by jax authors
parent d28c3fa409
commit 89a9c6c244
6 changed files with 43 additions and 4 deletions

@ -170,6 +170,12 @@ build:rocm --action_env=TF_HIPCC_CLANG="1"
# #############################################################################
# Cache options below.
# #############################################################################
# Public read-only cache. Uploading of locally built results is turned off, so
# a build that uses this config only reads from the cache bucket.
build:public_cache --remote_cache="https://storage.googleapis.com/jax-bazel-cache/" --remote_upload_local_results=false
# Cache pushes are limited to JAX's CI system. This config layers on top of
# public_cache, turns uploading of local results back on, and authenticates
# with Google default credentials.
build:public_cache_push --config=public_cache --remote_upload_local_results=true --google_default_credentials
# Note: the following cache configs are deprecated and will be removed soon.
# Public read-only cache for Mac builds. JAX uses a GCS bucket to store cache
# from JAX's Mac CI build. By applying --config=macos_cache, any local Mac build
# should be able to read from this cache and potentially see a speedup. The

@ -48,6 +48,8 @@ jobs:
env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
# Enable writing to the Bazel remote cache bucket.
JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1"
name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"

@ -122,6 +122,9 @@ jobs:
- name: Enable RBE if building on Linux x86 or Windows x86
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
if: contains(inputs.runner, 'linux-arm64')
run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main

@ -56,11 +56,22 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
# flags in the .bazelrc depending upon the platform we are building for.
bazelrc_config="${os}_${arch}"
# TODO(b/379903748): Add remote cache options for Linux and Windows.
# Platforms without RBE support fall back to the Bazel remote cache. The
# variable starts out empty so that it is always defined, which avoids unbound
# variable errors when it is expanded later.
bazel_remote_cache=""
case "$JAXCI_BUILD_ARTIFACT_WITH_RBE" in
  1)
    bazelrc_config="rbe_${bazelrc_config}"
    ;;
  *)
    bazelrc_config="ci_${bazelrc_config}"
    # Pick the remote cache flags: read-only unless writes were requested.
    # Pushes to the cache bucket are limited to JAX's CI system.
    bazel_remote_cache="--bazel_options=--config=public_cache"
    if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
      bazel_remote_cache="--bazel_options=--config=public_cache_push"
    fi
    ;;
esac
# Use the "_cuda" configs when building the CUDA artifacts.
@ -69,7 +80,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
fi
# Build the artifact. Note that `$bazel_remote_cache` is expanded unquoted, so
# when it is empty (the RBE case) it adds no argument to the command line.
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log
python build/build.py build --wheels="$artifact" \
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
--verbose --detailed_timestamped_log
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
# run `auditwheel show` to verify manylinux compliance.

@ -44,6 +44,12 @@ export JAXCI_OUTPUT_DIR="$(pwd)/dist"
# for CI builds where RBE is supported.
export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
# Where RBE is not supported, the Bazel remote cache is used to speed up
# builds. Setting this flag to 1 additionally makes Bazel try to push new cache
# entries to the bucket. Writing to the bucket requires authentication, so the
# flag is only enabled for CI builds; it defaults to 0 (read-only).
export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE="${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}"
# #############################################################################
# Test script specific environment variables.
# #############################################################################

@ -45,6 +45,14 @@ export num_cpu_cores=$(nproc)
# Cap the number of test jobs at the number of CPU cores.
if [[ "$num_test_jobs" -gt "$num_cpu_cores" ]]; then
  num_test_jobs="$num_cpu_cores"
fi
# Speed up builds with the Bazel remote cache. The read-only config is the
# default; the push config is selected only when cache writes are enabled.
# Pushes to the cache bucket are limited to JAX's CI system.
bazel_remote_cache="--config=public_cache"
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
  bazel_remote_cache="--config=public_cache_push"
fi
# End of test environment variables setup.
# Don't abort the script if one command fails to ensure we run both test
@ -55,7 +63,7 @@ set +e
# It appears --run_under needs an absolute path.
# The product of `JAX_ACCELERATOR_COUNT` and `JAX_TESTS_PER_ACCELERATOR`
# should match the VM's CPU core count (set in `--local_test_jobs`).
bazel test --config=ci_linux_x86_64_cuda \
bazel test --config=ci_linux_x86_64_cuda "$bazel_remote_cache" \
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
--//jax:build_jaxlib=false \
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
@ -79,7 +87,7 @@ first_bazel_cmd_retval=$?
echo "Running multi-accelerator tests (without RBE)..."
# Runs multi-accelerator tests with all GPUs directly on the VM without RBE.
bazel test --config=ci_linux_x86_64_cuda \
bazel test --config=ci_linux_x86_64_cuda "$bazel_remote_cache" \
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
--//jax:build_jaxlib=false \
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \