mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Add new Bazel remote cache configs
An example run where the cache configs are used: https://github.com/jax-ml/jax/actions/runs/12940123731 PiperOrigin-RevId: 719627011
This commit is contained in:
parent
d28c3fa409
commit
89a9c6c244
6
.bazelrc
6
.bazelrc
@ -170,6 +170,12 @@ build:rocm --action_env=TF_HIPCC_CLANG="1"
|
||||
# #############################################################################
|
||||
# Cache options below.
|
||||
# #############################################################################
|
||||
# Public read-only cache
|
||||
build:public_cache --remote_cache="https://storage.googleapis.com/jax-bazel-cache/" --remote_upload_local_results=false
|
||||
# Cache pushes are limited to JAX's CI system.
|
||||
build:public_cache_push --config=public_cache --remote_upload_local_results=true --google_default_credentials
|
||||
|
||||
# Note: the following cache configs are deprecated and will be removed soon.
|
||||
# Public read-only cache for Mac builds. JAX uses a GCS bucket to store cache
|
||||
# from JAX's Mac CI build. By applying --config=macos_cache, any local Mac build
|
||||
# should be able to read from this cache and potentially see a speedup. The
|
||||
|
2
.github/workflows/bazel_cuda_non_rbe.yml
vendored
2
.github/workflows/bazel_cuda_non_rbe.yml
vendored
@ -48,6 +48,8 @@ jobs:
|
||||
env:
|
||||
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
|
||||
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
|
||||
# Enable writing to the Bazel remote cache bucket.
|
||||
JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1"
|
||||
|
||||
name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
|
||||
|
||||
|
3
.github/workflows/build_artifacts.yml
vendored
3
.github/workflows/build_artifacts.yml
vendored
@ -122,6 +122,9 @@ jobs:
|
||||
- name: Enable RBE if building on Linux x86 or Windows x86
|
||||
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
|
||||
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
|
||||
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
|
||||
if: contains(inputs.runner, 'linux-arm64')
|
||||
run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
|
||||
# Halt for testing
|
||||
- name: Wait For Connection
|
||||
uses: google-ml-infra/actions/ci_connection@main
|
||||
|
@ -56,11 +56,22 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
|
||||
# flags in the .bazelrc depending upon the platform we are building for.
|
||||
bazelrc_config="${os}_${arch}"
|
||||
|
||||
# TODO(b/379903748): Add remote cache options for Linux and Windows.
|
||||
# On platforms with no RBE support, we can use the Bazel remote cache. Set
|
||||
# it to be empty by default to avoid unbound variable errors.
|
||||
bazel_remote_cache=""
|
||||
|
||||
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
|
||||
bazelrc_config="rbe_${bazelrc_config}"
|
||||
else
|
||||
bazelrc_config="ci_${bazelrc_config}"
|
||||
|
||||
# Set remote cache flags. Pushes to the cache bucket are limited to JAX's
|
||||
# CI system.
|
||||
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
|
||||
bazel_remote_cache="--bazel_options=--config=public_cache_push"
|
||||
else
|
||||
bazel_remote_cache="--bazel_options=--config=public_cache"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Use the "_cuda" configs when building the CUDA artifacts.
|
||||
@ -69,7 +80,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
|
||||
fi
|
||||
|
||||
# Build the artifact.
|
||||
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log
|
||||
python build/build.py build --wheels="$artifact" \
|
||||
--bazel_options=--config="$bazelrc_config" $bazel_remote_cache \
|
||||
--python_version=$JAXCI_HERMETIC_PYTHON_VERSION \
|
||||
--verbose --detailed_timestamped_log
|
||||
|
||||
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
|
||||
# run `auditwheel show` to verify manylinux compliance.
|
||||
|
@ -44,6 +44,12 @@ export JAXCI_OUTPUT_DIR="$(pwd)/dist"
|
||||
# for CI builds where RBE is supported.
|
||||
export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
|
||||
|
||||
# On platforms where RBE is not supported, we use Bazel remote cache to speed up
|
||||
# builds. When this flag is enabled, Bazel will also try to push new cache
|
||||
# entries to the bucket. Since writes to the bucket require authentication, this
|
||||
# flag is enabled only for CI builds.
|
||||
export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0}
|
||||
|
||||
# #############################################################################
|
||||
# Test script specific environment variables.
|
||||
# #############################################################################
|
||||
|
@ -45,6 +45,14 @@ export num_cpu_cores=$(nproc)
|
||||
if [[ $num_test_jobs -gt $num_cpu_cores ]]; then
|
||||
num_test_jobs=$num_cpu_cores
|
||||
fi
|
||||
|
||||
# Use the Bazel remote cache to speed up builds. Pushes to the cache bucket are
|
||||
# limited to JAX's CI system.
|
||||
if [[ "$JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE" == 1 ]]; then
|
||||
bazel_remote_cache="--config=public_cache_push"
|
||||
else
|
||||
bazel_remote_cache="--config=public_cache"
|
||||
fi
|
||||
# End of test environment variables setup.
|
||||
|
||||
# Don't abort the script if one command fails to ensure we run both test
|
||||
@ -55,7 +63,7 @@ set +e
|
||||
# It appears --run_under needs an absolute path.
|
||||
# The product of the `JAX_ACCELERATOR_COUNT` and `JAX_TESTS_PER_ACCELERATOR`
|
||||
# should match the VM's CPU core count (set in `--local_test_jobs`).
|
||||
bazel test --config=ci_linux_x86_64_cuda \
|
||||
bazel test --config=ci_linux_x86_64_cuda "$bazel_remote_cache" \
|
||||
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
|
||||
--//jax:build_jaxlib=false \
|
||||
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
|
||||
@ -79,7 +87,7 @@ first_bazel_cmd_retval=$?
|
||||
|
||||
echo "Running multi-accelerator tests (without RBE)..."
|
||||
# Runs multi-accelerator tests with all GPUs directly on the VM without RBE.
|
||||
bazel test --config=ci_linux_x86_64_cuda \
|
||||
bazel test --config=ci_linux_x86_64_cuda "$bazel_remote_cache" \
|
||||
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
|
||||
--//jax:build_jaxlib=false \
|
||||
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
|
||||
|
Loading…
x
Reference in New Issue
Block a user