rocm_jax/.github/workflows/pytest_tpu.yml
Nitin Srinivasan a6ab6bbc20 Ignore Pallas TPU tests when testing with the oldest supported libtpu
I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows

PiperOrigin-RevId: 736094492
2025-03-12 05:20:42 -07:00


# CI - Pytest TPU
#
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests.
#
# It consists of the following job:
# run-tests:
#    - Downloads the jax and jaxlib wheels from a GCS bucket.
#    - Sets up the libtpu wheels.
#    - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
#      - Installs the downloaded wheels.
#      - Runs the TPU tests with Pytest.
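#
# Example caller (a minimal sketch, not taken from the "CI - Wheel Tests" workflows; the job
# name `run-pytest-tpu` and the GCS URI are hypothetical placeholders). The runner, cores, and
# tpu-type values follow the v4-8 mapping documented under `inputs`:
#
#   jobs:
#     run-pytest-tpu:
#       uses: ./.github/workflows/pytest_tpu.yml
#       with:
#         runner: "linux-x86-ct4p-240-4tpu"
#         cores: "4"
#         tpu-type: "v4-8"
#         python: "3.12"
#         libtpu-version-type: "nightly"
#         gcs_download_uri: "gs://example-bucket/path/to/wheels"  # hypothetical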
name: CI - Pytest TPU
on:
  workflow_call:
    inputs:
      # Note that the values for runners, cores, and tpu-type are linked to each other.
      # For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
      # following mapping:
      # {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
      # {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: true
        default: "linux-x86-ct5lp-224-8tpu"
      cores:
        description: "How many TPU cores should the test use?"
        type: string
        required: true
        default: "8"
      tpu-type:
        description: "Which TPU type is used for testing?"
        type: string
        required: true
        default: "v5e-8"
      python:
        description: "Which Python version should be used for testing?"
        type: string
        required: true
        default: "3.12"
      run-full-tpu-test-suite:
        description: "Should the full TPU test suite be run?"
        type: string
        required: false
        default: "0"
      libtpu-version-type:
        description: "Which libtpu version should be used for testing?"
        type: string
        required: false
        # Choices are:
        # - "nightly": Use the nightly libtpu wheel.
        # - "pypi_latest": Use the latest libtpu wheel from PyPI.
        # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
        default: "nightly"
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false
jobs:
  run-tests:
    defaults:
      run:
        shell: bash
    runs-on: ${{ inputs.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    # Begin Presubmit Naming Check - name modification requires internal check to be updated
    name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
    # End Presubmit Naming Check github-tpu-presubmits
    env:
      LIBTPU_OLDEST_VERSION_DATE: 20241205
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
      JAXCI_TPU_CORES: "${{ inputs.cores }}"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          # Get the major and minor version of Python.
          # E.g., if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          # E.g., if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
          python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
          echo "OS=${os}" >> $GITHUB_ENV
          echo "ARCH=${arch}" >> $GITHUB_ENV
          # Python wheels follow a naming convention: standard wheels use the pattern
          # `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
          # `*-cp<py_version>-cp<py_version>t-*`.
          # E.g., python_major_minor=313t yields PYTHON_MAJOR_MINOR=cp313-cp313t-
          echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
      - name: Download JAX wheels from GCS
        id: download-wheel-artifacts
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the step below so that we can print a more
        # informative error message.
        continue-on-error: true
        run: |
          mkdir -p $(pwd)/dist
          gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
          gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
      - name: Skip the test run if the wheel artifacts were not downloaded successfully
        if: steps.download-wheel-artifacts.outcome == 'failure'
        run: |
          echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
          echo "built successfully by the artifact build jobs and are available in the GCS bucket."
          echo "Skipping the test run."
          exit 1
      - name: Install Python dependencies
        run: |
          $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
      - name: Set up libtpu wheels
        run: |
          if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then
            echo "Using nightly libtpu"
            $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then
            echo "Using latest libtpu from PyPI"
            # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
            # script will install the latest libtpu wheel from PyPI.
            echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV
          elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then
            echo "Using oldest supported libtpu"
            $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
            echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV
          else
            echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
            exit 1
          fi
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest TPU tests
        timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
        run: ./ci/run_pytest_tpu.sh