mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add job that runs Bazel single accelerator and multi-accelerator CUDA tests (non-RBE)
PiperOrigin-RevId: 718637923
This commit is contained in:
parent
704b2e5fba
commit
9aad6a6827
81
.github/workflows/bazel_cuda_non_rbe.yml
vendored
Normal file
81
.github/workflows/bazel_cuda_non_rbe.yml
vendored
Normal file
@ -0,0 +1,81 @@
|
||||
# CI - Bazel CUDA tests (Non-RBE)
|
||||
#
|
||||
# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via
|
||||
# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Bazel CUDA tests.
|
||||
#
|
||||
# It consists of the following job:
|
||||
# run-tests:
|
||||
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
|
||||
# - Executes the `run_bazel_test_cuda_non_rbe.sh` script, which performs the following actions:
|
||||
# - Installs the downloaded wheel artifacts.
|
||||
# - Runs the CUDA tests with Bazel.
|
||||
name: CI - Bazel CUDA tests (Non-RBE)
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
runner:
|
||||
description: "Which runner should the workflow run on?"
|
||||
type: string
|
||||
required: true
|
||||
default: "linux-x86-n2-16"
|
||||
python:
|
||||
description: "Which python version to test?"
|
||||
type: string
|
||||
required: true
|
||||
default: "3.12"
|
||||
enable-x64:
|
||||
description: "Should x64 mode be enabled?"
|
||||
type: string
|
||||
required: true
|
||||
default: "0"
|
||||
gcs_download_uri:
|
||||
description: "GCS location URI from where the artifacts should be downloaded"
|
||||
required: true
|
||||
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
|
||||
type: string
|
||||
halt-for-connection:
|
||||
description: 'Should this workflow run wait for a remote connection?'
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
run-tests:
|
||||
runs-on: ${{ inputs.runner }}
|
||||
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
|
||||
|
||||
env:
|
||||
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
|
||||
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
|
||||
|
||||
name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set env vars for use in artifact download URL
|
||||
run: |
|
||||
os=$(uname -s | awk '{print tolower($0)}')
|
||||
arch=$(uname -m)
|
||||
|
||||
# Get the major and minor version of Python.
|
||||
# E.g., if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
|
||||
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
|
||||
|
||||
echo "OS=${os}" >> $GITHUB_ENV
|
||||
echo "ARCH=${arch}" >> $GITHUB_ENV
|
||||
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
|
||||
- name: Download the wheel artifacts from GCS
|
||||
run: >-
|
||||
mkdir -p $(pwd)/dist &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
|
||||
# Halt for testing
|
||||
- name: Wait For Connection
|
||||
uses: google-ml-infra/actions/ci_connection@main
|
||||
with:
|
||||
halt-dispatch-input: ${{ inputs.halt-for-connection }}
|
||||
- name: Run Bazel CUDA tests (Non-RBE)
|
||||
timeout-minutes: 60
|
||||
run: ./ci/run_bazel_test_cuda_non_rbe.sh
|
2
.github/workflows/pytest_cuda.yml
vendored
2
.github/workflows/pytest_cuda.yml
vendored
@ -7,7 +7,7 @@
|
||||
# run-tests:
|
||||
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
|
||||
# - Executes the `run_pytest_cuda.sh` script, which performs the following actions:
|
||||
# - Installs the downloaded jaxlib wheel.
|
||||
# - Installs the downloaded wheel artifacts.
|
||||
# - Runs the CUDA tests with Pytest.
|
||||
name: CI - Pytest CUDA
|
||||
|
||||
|
23
.github/workflows/wheel_tests_continuous.yml
vendored
23
.github/workflows/wheel_tests_continuous.yml
vendored
@ -11,6 +11,10 @@
|
||||
# uploads them to a GCS bucket.
|
||||
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts
|
||||
# that were built in the previous steps and runs the CUDA tests.
|
||||
# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow to download the jaxlib and
|
||||
# CUDA artifacts that were built in the previous steps and runs the CUDA
|
||||
# tests using Bazel.
|
||||
|
||||
name: CI - Wheel Tests (Continuous)
|
||||
|
||||
on:
|
||||
@ -44,7 +48,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
# Python values need to match the matrix stategy in the GPU tests job below
|
||||
# Python values need to match the matrix strategy in the CUDA tests job below
|
||||
runner: ["linux-x86-n2-16"]
|
||||
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
|
||||
python: ["3.10",]
|
||||
@ -99,3 +103,20 @@ jobs:
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
# GCS upload URI is the same for both artifact build jobs
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
||||
|
||||
run-bazel-test-cuda:
|
||||
needs: [build-jaxlib-artifact, build-cuda-artifacts]
|
||||
uses: ./.github/workflows/bazel_cuda_non_rbe.yml
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
# Python values need to match the matrix strategy in the build artifacts job above
|
||||
runner: ["linux-x86-g2-48-l4-4gpu",]
|
||||
python: ["3.10",]
|
||||
enable-x64: [1, 0]
|
||||
with:
|
||||
runner: ${{ matrix.runner }}
|
||||
python: ${{ matrix.python }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
# GCS upload URI is the same for both artifact build jobs
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
@ -47,6 +47,10 @@ if [[ $num_test_jobs -gt $num_cpu_cores ]]; then
|
||||
fi
|
||||
# End of test environment variables setup.
|
||||
|
||||
# Don't abort the script when one command fails, so that both test
|
||||
# commands below always run.
|
||||
set +e
|
||||
|
||||
# Runs single accelerator tests with one GPU apiece.
|
||||
# It appears --run_under needs an absolute path.
|
||||
# The product of the `JAX_ACCELERATOR_COUNT` and `JAX_TESTS_PER_ACCELERATOR`
|
||||
@ -70,6 +74,9 @@ bazel test --config=ci_linux_x86_64_cuda \
|
||||
//tests:gpu_tests //tests:backend_independent_tests \
|
||||
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
|
||||
|
||||
# Store the return value of the first bazel command.
|
||||
first_bazel_cmd_retval=$?
|
||||
|
||||
echo "Running multi-accelerator tests (without RBE)..."
|
||||
# Runs multi-accelerator tests with all GPUs directly on the VM without RBE.
|
||||
bazel test --config=ci_linux_x86_64_cuda \
|
||||
@ -85,3 +92,15 @@ bazel test --config=ci_linux_x86_64_cuda \
|
||||
--action_env=NCCL_DEBUG=WARN \
|
||||
--color=yes \
|
||||
//tests:gpu_tests //tests/pallas:gpu_tests
|
||||
|
||||
# Store the return value of the second bazel command.
|
||||
second_bazel_cmd_retval=$?
|
||||
|
||||
# Exit with failure if either command fails.
|
||||
if [[ $first_bazel_cmd_retval -ne 0 ]]; then
|
||||
exit $first_bazel_cmd_retval
|
||||
elif [[ $second_bazel_cmd_retval -ne 0 ]]; then
|
||||
exit $second_bazel_cmd_retval
|
||||
else
|
||||
exit 0
|
||||
fi
|
Loading…
x
Reference in New Issue
Block a user