rocm_jax/.github/workflows/pytest_cpu.yml
Nitin Srinivasan 4b4f2f9cb9 Use uv to install Python packages
PiperOrigin-RevId: 730499307
2025-02-24 10:13:39 -08:00

137 lines
6.0 KiB
YAML

# CI - Pytest CPU
#
# This workflow runs the CPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest CPU tests.
#
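# It consists of the following job:
# run-tests:
#   - Downloads the jaxlib wheel from a GCS bucket (and also the jax wheel when
#     `install-jax-current-commit` is not "1").
#   - Installs the Python dependencies listed in `build/requirements.in` with uv.
#   - Executes the `run_pytest_cpu.sh` script, which performs the following actions:
#     - Installs the downloaded jaxlib wheel.
#     - Runs the CPU tests with Pytest.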
name: CI - Pytest CPU
on:
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: true
        default: "linux-x86-n2-16"
      # Quoted so that versions like "3.10" are not read as the float 3.1.
      python:
        description: "Which Python version should the tests run with (also selects the matching jaxlib wheel)?"
        type: string
        required: true
        default: "3.12"
      # String "0"/"1"; exported to the job as JAXCI_ENABLE_X64.
      enable-x64:
        description: "Should x64 mode be enabled?"
        type: string
        required: true
        default: "0"
      # String "0"/"1"; when not "1", the prebuilt "jax" wheel is downloaded from GCS as well.
      install-jax-current-commit:
        description: "Should the 'jax' package be installed from the current commit?"
        type: string
        required: true
        default: "1"
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        type: string
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      halt-for-connection:
        description: "Should this workflow run wait for a remote connection?"
        type: boolean
        required: false
        default: false
jobs:
  run-tests:
    defaults:
      run:
        # Explicitly set the shell to bash to override Windows's default (cmd)
        shell: bash
    runs-on: ${{ inputs.runner }}
    # Linux runners use the ml-build image for their architecture; for Windows runners the
    # expression evaluates to null, so the job runs directly on the host.
    container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
                   (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
                   (contains(inputs.runner, 'windows-x86') && null) }}
    name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
      JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      # Exports OS, ARCH and PYTHON_MAJOR_MINOR, which the download steps use to build the
      # wheel filename glob.
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)

          # Adjust os and arch for Windows
          if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then
            os="win"
            arch="amd64"
          fi

          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')

          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> "$GITHUB_ENV"
      - name: Download jaxlib wheel from GCS (non-Windows runs)
        id: download-wheel-artifacts-nw
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the step below so that we can print a more
        # informative error message.
        continue-on-error: true
        if: ${{ !contains(inputs.runner, 'windows-x86') }}
        run: |
          # Any failed download must fail the step so the check below can report it.
          set -eo pipefail
          mkdir -p "$(pwd)/dist"
          # The wildcards are quoted so that gsutil, not bash, expands them against the bucket.
          gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/"

          # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
          if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
            gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any.whl" "$(pwd)/dist/"
          fi
      - name: Download jaxlib wheel from GCS (Windows runs)
        id: download-wheel-artifacts-w
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the step below so that we can print a more
        # informative error message.
        continue-on-error: true
        if: ${{ contains(inputs.runner, 'windows-x86') }}
        shell: cmd
        run: |
          if not exist dist mkdir dist
          @REM Use `call` so that we can run sequential gsutil commands on Windows
          @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
          call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
          @REM cmd keeps going after a failed command, and a later successful gsutil call would
          @REM reset the exit code; fail here so a missing jaxlib wheel is not hidden.
          if errorlevel 1 exit /b 1
          @REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
          if not "${{ inputs.install-jax-current-commit }}"=="1" (
            call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any.whl" dist/
          )
      # `outcome` is checked (not `conclusion`) because continue-on-error reports the
      # conclusion of a failed download step as "success".
      - name: Skip the test run if the wheel artifacts were not downloaded successfully
        if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
        run: |
          echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
          echo "built successfully by the artifact build jobs and are available in the GCS bucket."
          echo "Skipping the test run."
          exit 1
      - name: Install Python dependencies
        run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
      # Halt for testing
      # NOTE(review): this action tracks `@main` while actions/checkout is pinned to a commit
      # SHA; consider pinning it to a SHA as well.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest CPU tests
        timeout-minutes: 60
        run: ./ci/run_pytest_cpu.sh