mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
137 lines
6.0 KiB
YAML
137 lines
6.0 KiB
YAML
# CI - Pytest CPU
#
# This workflow runs the CPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest CPU tests.
#
# It consists of the following job:
# run-tests:
#    - Downloads the jaxlib wheel from a GCS bucket.
#    - Executes the `run_pytest_cpu.sh` script, which performs the following actions:
#      - Installs the downloaded jaxlib wheel.
#      - Runs the CPU tests with Pytest.
name: CI - Pytest CPU

# Reusable workflow: the only trigger is `workflow_call`, so it runs solely when
# another workflow invokes it and supplies the inputs declared below.
on:
  workflow_call:
    inputs:
      # Runner label; the job also uses it to pick the container image.
      runner:
        description: 'Which runner should the workflow run on?'
        type: string
        required: true
        default: 'linux-x86-n2-16'
      # Python version, exported to the job as JAXCI_HERMETIC_PYTHON_VERSION and
      # used to build the interpreter name (JAXCI_PYTHON).
      python:
        description: 'Which python version should the artifact be built for?'
        type: string
        required: true
        default: '3.12'
      # Exported to the job as JAXCI_ENABLE_X64.
      enable-x64:
        description: 'Should x64 mode be enabled?'
        type: string
        required: true
        default: '0'
      # Exported to the job as JAXCI_INSTALL_JAX_CURRENT_COMMIT; any value other
      # than "1" makes the download steps fetch the "jax" wheel from GCS as well.
      install-jax-current-commit:
        description: "Should the 'jax' package be installed from the current commit?"
        type: string
        required: true
        default: '1'
      # Prefix under which the download steps look for the wheel files.
      gcs_download_uri:
        description: 'GCS location prefix from where the artifacts should be downloaded'
        type: string
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      # Forwarded to the "Wait For Connection" step.
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false

jobs:
  # Downloads the prebuilt wheel(s) from GCS into ./dist and then runs the Pytest CPU
  # suite via ci/run_pytest_cpu.sh.
  run-tests:
    defaults:
      run:
        # Explicitly set the shell to bash to override Windows's default (cmd)
        shell: bash
    runs-on: ${{ inputs.runner }}
    # Linux runners execute inside the ml-build image matching their architecture; for
    # Windows runners the expression yields null.
    container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
                   (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
                   (contains(inputs.runner, 'windows-x86') && null) }}

    name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"

    # Workflow inputs re-exported as environment variables for every step in the job.
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
      JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      # Publishes OS, ARCH and PYTHON_MAJOR_MINOR through GITHUB_ENV so that the download
      # steps below can build the wheel filename pattern.
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)

          # Adjust os and arch for Windows
          if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then
            os="win"
            arch="amd64"
          fi

          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')

          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> "$GITHUB_ENV"
      - name: Download jaxlib wheel from GCS (non-Windows runs)
        id: download-wheel-artifacts-nw
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the step below so that we can print a more
        # informative error message.
        continue-on-error: true
        if: ${{ !contains(inputs.runner, 'windows-x86') }}
        env:
          # Hand the input to the script as data instead of splicing it into the script text.
          GCS_DOWNLOAD_URI: "${{ inputs.gcs_download_uri }}"
        run: |
          # Each command stands on its own line (no `&&` chain) so that bash's errexit
          # mode aborts the step as soon as any one of them fails.
          mkdir -p "$(pwd)/dist"
          # The wildcards are quoted so that gsutil, not the local shell, expands them.
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/"

          # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
          # (the input is available here as JAXCI_INSTALL_JAX_CURRENT_COMMIT from the job env).
          if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" != 1 ]]; then
            gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*py3*none*any.whl" "$(pwd)/dist/"
          fi
      - name: Download jaxlib wheel from GCS (Windows runs)
        id: download-wheel-artifacts-w
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in step below so that we can print a more
        # informative error message.
        continue-on-error: true
        if: ${{ contains(inputs.runner, 'windows-x86') }}
        shell: cmd
        run: |
          mkdir dist
          @REM Use `call` so that we can run sequential gsutil commands on Windows
          @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
          call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
          @REM cmd keeps going after a failed command, so stop here explicitly; otherwise a
          @REM successful "jax" download below would hide a failed jaxlib download.
          if errorlevel 1 exit /b 1

          @REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
          if not "${{ inputs.install-jax-current-commit }}"=="1" (
            call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
          )
      # Runs only when one of the two download steps failed; prints the reason and fails the job.
      - name: Skip the test run if the wheel artifacts were not downloaded successfully
        if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
        run: |
          echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
          echo "built successfully by the artifact build jobs and are available in the GCS bucket."
          echo "Skipping the test run."
          exit 1
      - name: Install Python dependencies
        run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
      # Halt for testing
      # NOTE(review): this action tracks the mutable `main` branch while actions/checkout
      # above is pinned to a commit SHA - consider pinning this one as well.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest CPU tests
        timeout-minutes: 60
        run: ./ci/run_pytest_cpu.sh