Add TPU test jobs to the new CI continuous and nightly/release test workflows

Also, modify the TPU presubmit workflow to reuse the `build_artifacts.yml` and `pytest_tpu.yml` workflows.

PiperOrigin-RevId: 735832964
This commit is contained in:
Nitin Srinivasan 2025-03-11 11:41:34 -07:00 committed by jax authors
parent c2c68c018f
commit 7ac6355262
7 changed files with 281 additions and 77 deletions

View File

@ -3,6 +3,7 @@
# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
# tested to get to a stable state before we enable it as a blocking presubmit.
name: CI - Cloud TPU (presubmit)
on:
workflow_dispatch:
inputs:
@ -33,64 +34,32 @@ concurrency:
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
jobs:
cloud-tpu-test:
build-jax-artifacts:
if: github.event.repository.fork == false
# Begin Presubmit Naming Check - name modification requires internal check to be updated
uses: ./.github/workflows/build_artifacts.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
tpu: [
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
python-version: ["3.10"]
name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
# End Presubmit Naming Check github-tpu-presubmits
env:
JAXCI_PYTHON: python${{ matrix.python-version }}
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
fail-fast: false # don't cancel all jobs on failure
matrix:
artifact: ["jax", "jaxlib"]
with:
runner: "linux-x86-n2-16"
artifact: ${{ matrix.artifact }}
python: "3.10"
clone_main_xla: 1
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
runs-on: ${{ matrix.tpu.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
timeout-minutes: 60
defaults:
run:
shell: bash -ex {0}
steps:
# https://opensource.google/documentation/reference/github/services#actions
# mandates using a specific commit for non-Google actions. We use
# https://github.com/sethvargo/ratchet to pin specific versions.
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# Checkout XLA at head, if we're building jaxlib at head.
- name: Checkout XLA at head
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: openxla/xla
path: xla
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
- name: Mark GitHub workspace as safe
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Build jaxlib at head with latest XLA
run: |
# Build and install jaxlib at head
$JAXCI_PYTHON build/build.py build --wheels=jaxlib \
--python_version=${{ matrix.python-version }} \
--bazel_options=--config=rbe_linux_x86_64 \
--local_xla_path="$(pwd)/xla" \
--verbose
# Install libtpu
$JAXCI_PYTHON -m uv pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Install jaxlib wheel and run tests
run: ./ci/run_pytest_tpu.sh
run-pytest-tpu:
if: github.event.repository.fork == false
needs: [build-jax-artifacts]
uses: ./.github/workflows/pytest_tpu.yml
# Begin Presubmit Naming Check - name modification requires internal check to be updated
name: "TPU test (jaxlib=head, v5e-8)"
with:
runner: "linux-x86-ct5lp-224-8tpu"
cores: "8"
tpu-type: "v5e-8"
python: "3.10"
libtpu-version-type: "nightly"
gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }}
# End Presubmit Naming Check github-tpu-presubmits

149
.github/workflows/pytest_tpu.yml vendored Normal file
View File

@ -0,0 +1,149 @@
# CI - Pytest TPU
#
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows and by the
# "CI - Cloud TPU (presubmit)" workflow to run the Pytest TPU tests.
#
# It consists of the following job:
# run-tests:
# - Downloads the jax and jaxlib wheels from a GCS bucket.
# - Sets up the libtpu wheels.
# - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
# - Installs the downloaded jaxlib wheel.
# - Runs the TPU tests with Pytest.
name: CI - Pytest TPU
on:
workflow_call:
inputs:
# Note that the values for runner, cores, and tpu-type are linked to each other.
# For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
# following mapping:
# {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
# {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
runner:
description: "Which runner should the workflow run on?"
type: string
required: true
default: "linux-x86-ct5lp-224-8tpu"
cores:
description: "How many TPU cores should the test use?"
type: string
required: true
default: "8"
tpu-type:
description: "Which TPU type is used for testing?"
type: string
required: true
default: "v5e-8"
python:
description: "Which Python version should be used for testing?"
type: string
required: true
default: "3.12"
run-full-tpu-test-suite:
description: "Should the full TPU test suite be run?"
type: string
required: false
default: "0"
libtpu-version-type:
description: "Which libtpu version should be used for testing?"
type: string
required: false
# Choices are:
# - "nightly": Use the nightly libtpu wheel.
# - "pypi_latest": Use the latest libtpu wheel from PyPI.
# - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
default: "nightly"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: boolean
required: false
default: false
jobs:
run-tests:
defaults:
run:
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
# Begin Presubmit Naming Check - name modification requires internal check to be updated
name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
# End Presubmit Naming Check github-tpu-presubmits
env:
LIBTPU_OLDEST_VERSION_DATE: 20241205
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
JAXCI_TPU_CORES: "${{ inputs.cores }}"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set env vars for use in artifact download URL
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)
# Get the major and minor version of Python.
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
echo "OS=${os}" >> $GITHUB_ENV
echo "ARCH=${arch}" >> $GITHUB_ENV
# Python wheels follow a naming convention: standard wheels use the pattern
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download JAX wheels from GCS
id: download-wheel-artifacts
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
# informative error message.
continue-on-error: true
run: |
mkdir -p $(pwd)/dist
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
echo "Skipping the test run."
exit 1
- name: Install Python dependencies
run: |
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Set up libtpu wheels
run: |
if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then
echo "Using nightly libtpu"
$JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then
echo "Using latest libtpu from PyPI"
# Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
# script will install the latest libtpu wheel from PyPI.
echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV
elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then
echo "Using oldest supported libtpu"
$JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
else
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
exit 1
fi
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Pytest TPU tests
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
run: ./ci/run_pytest_tpu.sh

View File

@ -142,4 +142,30 @@ jobs:
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-pytest-tpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
python: ["3.10",]
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
name: "TPU tests (jax=head, jaxlib=head)"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: "nightly"
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

View File

@ -58,4 +58,29 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{inputs.gcs_download_uri}}
run-pytest-tpu:
uses: ./.github/workflows/pytest_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
python: ["3.10","3.11", "3.12", "3.13"]
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"]
exclude:
- libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }}
- libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }}
name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
gcs_download_uri: ${{inputs.gcs_download_uri}}

View File

@ -74,4 +74,14 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
# on the system. By default, it is set to match the version of the hermetic
# Python used by Bazel for building the wheels.
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
# When set to 1, the full TPU test suite is run. Otherwise, a subset of tests
# is run.
export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0}
# We use this environment variable to control which additional wheels to install
# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest
# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the
# list of valid values and their behavior.
export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""}

View File

@ -52,23 +52,37 @@ export JAX_SKIP_SLOW_TESTS=true
echo "Running TPU tests..."
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" \
tests/pallas/ops_test.py \
tests/pallas/export_back_compat_pallas_test.py \
tests/pallas/export_pallas_test.py \
tests/pallas/tpu_ops_test.py \
tests/pallas/tpu_pallas_test.py \
tests/pallas/tpu_pallas_random_test.py \
tests/pallas/tpu_pallas_async_test.py \
tests/pallas/tpu_pallas_state_test.py
if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
tests/pjit_test.py \
tests/pallas/tpu_pallas_distributed_test.py
# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
else
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" \
tests/pallas/ops_test.py \
tests/pallas/export_back_compat_pallas_test.py \
tests/pallas/export_pallas_test.py \
tests/pallas/tpu_ops_test.py \
tests/pallas/tpu_pallas_test.py \
tests/pallas/tpu_pallas_random_test.py \
tests/pallas/tpu_pallas_async_test.py \
tests/pallas/tpu_pallas_state_test.py
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
tests/pjit_test.py \
tests/pallas/tpu_pallas_distributed_test.py
fi

View File

@ -17,8 +17,19 @@
# Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python
# binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to
# avoid using the Windows version of `find` on Msys.
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
for i in "${!WHEELS[@]}"; do
if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then
if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then
# Append [tpu] to the jax wheel name to download the latest libtpu wheel
# from PyPI.
WHEELS[$i]="${WHEELS[$i]}[tpu]"
fi
fi
done
if [[ -z "${WHEELS[@]}" ]]; then
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
exit 1