Mirror of https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows. PiperOrigin-RevId: 736094492
152 lines · 6.9 KiB · YAML
# CI - Pytest TPU
#
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests.
#
# It consists of the following job:
# run-tests:
#   - Downloads the jax and jaxlib wheels from a GCS bucket.
#   - Installs the Python test dependencies.
#   - Sets up the libtpu wheels.
#   - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
#     - Installs the downloaded jaxlib wheel.
#     - Runs the TPU tests with Pytest.
name: CI - Pytest TPU

on:
  workflow_call:
    inputs:
      # The runner, cores and tpu-type inputs must agree with one another; for example the
      # v5e-8 TPU type requires 8 cores. Known-good combinations:
      #   {tpu-type: "v4-8",  cores: "4", runner: "linux-x86-ct4p-240-4tpu"}
      #   {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
      # NOTE(review): inputs marked `required: true` must always be passed by the caller, so
      # their `default` values are presumably informational only — confirm before relying on them.
      runner:
        description: 'Which runner should the workflow run on?'
        type: string
        required: true
        default: 'linux-x86-ct5lp-224-8tpu'
      cores:
        description: 'How many TPU cores should the test use?'
        type: string
        required: true
        default: '8'
      tpu-type:
        description: 'Which TPU type is used for testing?'
        type: string
        required: true
        default: 'v5e-8'
      python:
        description: 'Which Python version should be used for testing?'
        type: string
        required: true
        default: '3.12'
      run-full-tpu-test-suite:
        description: 'Should the full TPU test suite be run?'
        type: string
        required: false
        default: '0'
      libtpu-version-type:
        description: 'Which libtpu version should be used for testing?'
        type: string
        required: false
        # Accepted values:
        #   nightly                 - the nightly libtpu wheel
        #   pypi_latest             - the latest libtpu wheel from PyPI
        #   oldest_supported_libtpu - the oldest supported libtpu wheel
        default: 'nightly'
      gcs_download_uri:
        description: 'GCS location prefix from where the artifacts should be downloaded'
        type: string
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false

jobs:
  # Single job: fetch the prebuilt wheels, prepare libtpu, then hand off to ci/run_pytest_tpu.sh.
  run-tests:
    defaults:
      run:
        shell: bash
    runs-on: ${{ inputs.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    # Begin Presubmit Naming Check - name modification requires internal check to be updated
    name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
    # End Presubmit Naming Check github-tpu-presubmits

    env:
      # Quoted so the date stays a string rather than being parsed as an integer.
      LIBTPU_OLDEST_VERSION_DATE: "20241205"
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
      JAXCI_TPU_CORES: "${{ inputs.cores }}"

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)

          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
          python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          # Python wheels follow a naming convention: standard wheels use the pattern
          # `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
          # `*-cp<py_version>-cp<py_version>t-*`.
          echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> "$GITHUB_ENV"
      - name: Download JAX wheels from GCS
        id: download-wheel-artifacts
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the step below so that we can print a more
        # informative error message.
        continue-on-error: true
        env:
          # Passed via env rather than `${{ }}` inside the script so the value is never parsed
          # as shell code.
          GCS_DOWNLOAD_URI: ${{ inputs.gcs_download_uri }}
        run: |
          mkdir -p "$(pwd)/dist"
          # Each source URL is quoted in full so the wildcards are expanded by gsutil against
          # the bucket, not by the local shell.
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*py3*none*any.whl" "$(pwd)/dist/"
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
      # Despite the step name, this step fails the job (exit 1) when the download step failed.
      - name: Skip the test run if the wheel artifacts were not downloaded successfully
        if: steps.download-wheel-artifacts.outcome == 'failure'
        run: |
          echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
          echo "built successfully by the artifact build jobs and are available in the GCS bucket."
          echo "Skipping the test run."
          exit 1
      - name: Install Python dependencies
        run: |
          $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
      - name: Set up libtpu wheels
        env:
          # Passed via env rather than `${{ }}` inside the script so the value is never parsed
          # as shell code.
          LIBTPU_VERSION_TYPE: ${{ inputs.libtpu-version-type }}
        run: |
          case "${LIBTPU_VERSION_TYPE}" in
            nightly)
              echo "Using nightly libtpu"
              $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              ;;
            pypi_latest)
              echo "Using latest libtpu from PyPI"
              # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
              # script will install the latest libtpu wheel from PyPI.
              echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> "$GITHUB_ENV"
              ;;
            oldest_supported_libtpu)
              echo "Using oldest supported libtpu"
              $JAXCI_PYTHON -m uv pip install --pre "libtpu-nightly==0.1.dev${LIBTPU_OLDEST_VERSION_DATE}" \
                -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

              # NOTE(review): presumably consumed by ci/run_pytest_tpu.sh — confirm.
              echo "libtpu_version_type=oldest_supported_libtpu" >> "$GITHUB_ENV"
              ;;
            *)
              echo "Unknown libtpu version type: ${LIBTPU_VERSION_TYPE}"
              exit 1
              ;;
          esac
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest TPU tests
        # Shorter budget for pull requests; 180 minutes otherwise.
        timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
        run: ./ci/run_pytest_tpu.sh