Add job that runs Bazel single accelerator and multi-accelerator CUDA tests (non-RBE)

PiperOrigin-RevId: 718637923
This commit is contained in:
Nitin Srinivasan 2025-01-22 17:51:10 -08:00 committed by jax authors
parent 704b2e5fba
commit 9aad6a6827
4 changed files with 124 additions and 3 deletions

View File

@ -0,0 +1,81 @@
# CI - Bazel CUDA tests (Non-RBE)
#
# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via
# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Bazel CUDA tests.
#
# It consists of the following job:
# run-tests:
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
# - Executes the `run_bazel_test_cuda_non_rbe.sh` script, which performs the following actions:
# - Installs the downloaded wheel artifacts.
# - Runs the CUDA tests with Bazel.
name: CI - Bazel CUDA tests (Non-RBE)
# This workflow declares `workflow_call` as its only trigger, so it runs only
# when another workflow invokes it and supplies the inputs below.
on:
  workflow_call:
    inputs:
      # NOTE(review): every input except `halt-for-connection` is declared
      # `required: true` and also carries a `default`. If callers must always
      # pass a value, the defaults look unreachable -- confirm which of the two
      # is intended.
      runner:
        description: 'Which runner should the workflow run on?'
        type: string
        required: true
        default: 'linux-x86-n2-16'
      python:
        description: 'Which python version to test?'
        type: string
        required: true
        # Kept as a quoted string so the version is never parsed as a float.
        default: '3.12'
      enable-x64:
        description: 'Should x64 mode be enabled?'
        type: string
        required: true
        default: '0'
      gcs_download_uri:
        description: 'GCS location URI from where the artifacts should be downloaded'
        type: string
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false
jobs:
  # Single job: download the prebuilt wheels from GCS, then run the Bazel CUDA
  # test script on the runner selected by the caller.
  run-tests:
    runs-on: ${{ inputs.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
    env:
      # Job-wide environment. JAXCI_HERMETIC_PYTHON_VERSION is read by the
      # "Set env vars" step below; JAXCI_ENABLE_X64 is presumably consumed by
      # the test script -- confirm against ci/run_bazel_test_cuda_non_rbe.sh.
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
      JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
    name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
          # Persist the values so later steps can use them in wheel globs.
          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> "$GITHUB_ENV"
      - name: Download the wheel artifacts from GCS
        env:
          # Pass the caller-supplied URI through the environment instead of
          # expanding the expression inside the script, so its contents can
          # never be interpreted as shell syntax.
          GCS_DOWNLOAD_URI: ${{ inputs.gcs_download_uri }}
        # Folded scalar: the lines below are joined into one `&&` chain, so a
        # failed download stops the remaining copies.
        run: >-
          mkdir -p "$(pwd)/dist" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
      # Halt for testing
      # NOTE(review): unlike the checkout action above, this action tracks the
      # mutable `main` branch -- consider pinning it to a commit SHA.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Bazel CUDA tests (Non-RBE)
        timeout-minutes: 60
        run: ./ci/run_bazel_test_cuda_non_rbe.sh

View File

@ -7,7 +7,7 @@
# run-tests:
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
# - Executes the `run_pytest_cuda.sh` script, which performs the following actions:
# - Installs the downloaded jaxlib wheel.
# - Installs the downloaded wheel artifacts.
# - Runs the CUDA tests with Pytest.
name: CI - Pytest CUDA

View File

@ -11,6 +11,10 @@
# uploads them to a GCS bucket.
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts
# that were built in the previous steps and runs the CUDA tests.
# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow to download the jaxlib and
# CUDA artifacts that were built in the previous steps and runs the CUDA
# tests using Bazel.
name: CI - Wheel Tests (Continuous)
on:
@ -44,7 +48,7 @@ jobs:
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix stategy in the GPU tests job below
# Python values need to match the matrix strategy in the CUDA tests job below
runner: ["linux-x86-n2-16"]
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
python: ["3.10",]
@ -99,3 +103,20 @@ jobs:
enable-x64: ${{ matrix.enable-x64 }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-bazel-test-cuda:
  # Calls the reusable Bazel CUDA workflow once per matrix combination, after
  # both artifact build jobs have finished.
  needs: [build-jaxlib-artifact, build-cuda-artifacts]
  uses: ./.github/workflows/bazel_cuda_non_rbe.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Python values need to match the matrix strategy in the build artifacts job above
      runner: ["linux-x86-g2-48-l4-4gpu"]
      python: ["3.10"]
      # Quoted because the called workflow declares `enable-x64` as a
      # `type: string` input.
      enable-x64: ["1", "0"]
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    # GCS upload URI is the same for both artifact build jobs
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

View File

@ -47,6 +47,10 @@ if [[ $num_test_jobs -gt $num_cpu_cores ]]; then
fi
# End of test environment variables setup.
# Don't abort the script if one command fails to ensure we run both test
# commands below.
set +e
# Runs single accelerator tests with one GPU apiece.
# It appears --run_under needs an absolute path.
# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR`
@ -70,6 +74,9 @@ bazel test --config=ci_linux_x86_64_cuda \
//tests:gpu_tests //tests:backend_independent_tests \
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
# Store the return value of the first bazel command.
first_bazel_cmd_retval=$?
echo "Running multi-accelerator tests (without RBE)..."
# Runs multiaccelerator tests with all GPUs directly on the VM without RBE..
bazel test --config=ci_linux_x86_64_cuda \
@ -85,3 +92,15 @@ bazel test --config=ci_linux_x86_64_cuda \
--action_env=NCCL_DEBUG=WARN \
--color=yes \
//tests:gpu_tests //tests/pallas:gpu_tests
# Capture the exit status of the second bazel command immediately, before any
# other command overwrites $?.
second_bazel_cmd_retval=$?

# Exit with the first non-zero status, checking the first command before the
# second; exit 0 only when both commands succeeded.
for retval in "$first_bazel_cmd_retval" "$second_bazel_cmd_retval"; do
  if [[ $retval -ne 0 ]]; then
    exit "$retval"
  fi
done
exit 0