Add workflow for testing nightly/release artifacts

This commits adds a Github action workflow that will be used to jobs that test the nightly/release artifacts. These artifacts are built by our internal CI jobs and are uploaded to a transient GCS bucket. After all the wheels have finished uploading, an internal job is run that that will trigger the `wheel_tests_nightly_release.yml` workflow.

PiperOrigin-RevId: 716789482
This commit is contained in:
Nitin Srinivasan 2025-01-17 13:53:15 -08:00 committed by jax authors
parent 12beb00bb3
commit 9fb29766a2
3 changed files with 78 additions and 4 deletions

View File

@ -0,0 +1,67 @@
# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU/CUDA tests.
#
# It orchestrates the following:
# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
# uploads it to a GCS bucket.
# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow to download the jaxlib wheel that was built
# in the previous step and runs CPU tests.
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
# uploads them to a GCS bucket.
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts
# that were built in the previous steps and runs the CUDA tests.
name: CI - Wheel Tests (Nightly/Release)
on:
workflow_dispatch:
inputs:
gcs_download_uri:
description: "GCS location URI from where the artifacts should be downloaded"
required: true
default: 'gs://jax-nightly-release-transient/nightly/latest'
type: string
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
env:
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "0"
jobs:
run-pytest-cpu:
uses: ./.github/workflows/pytest_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Runner OS and Python values need to match the matrix stategy of our internal CI jobs
# that build the wheels.
runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
python: ["3.10","3.11", "3.12", "3.13"]
enable-x64: [0]
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{inputs.gcs_download_uri}}
run-pytest-cuda:
uses: ./.github/workflows/pytest_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Runner OS and Python values need to match the matrix stategy of our internal CI jobs
# that build the wheels.
runner: ["linux-x86-g2-48-l4-4gpu"]
python: ["3.10","3.11", "3.12", "3.13"]
cuda: ["12.3", "12.1"]
enable-x64: [0]
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{inputs.gcs_download_uri}}

View File

@ -67,3 +67,8 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
# on the system. By default, it is set to match the version of the hermetic
# Python used by Bazel for building the wheels.
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
# Installs the JAX package in editable mode at the current commit. Enabled by
# default. Nightly/Release builds disable this flag in the Github action
# workflow files.
export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"}

View File

@ -17,7 +17,7 @@
# Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python
# binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to
# avoid using the Windows version of `find` on Msys.
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
if [[ -z "$WHEELS" ]]; then
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
@ -34,6 +34,8 @@ else
"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}"
fi
echo "Installing the JAX package in editable mode at the current commit..."
# Install JAX package at the current commit.
"$JAXCI_PYTHON" -m pip install -U -e .
if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then
echo "Installing the JAX package in editable mode at the current commit..."
# Install JAX package at the current commit.
"$JAXCI_PYTHON" -m pip install -U -e .
fi