rocm_jax/.github/workflows/pytest_cuda.yml
Nitin Srinivasan 4b4f2f9cb9 Use uv to install Python packages
PiperOrigin-RevId: 730499307
2025-02-24 10:13:39 -08:00

117 lines
5.1 KiB
YAML

# CI - Pytest CUDA
#
# This workflow runs the CUDA tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Pytest CUDA tests.
#
# It consists of the following job:
# run-tests:
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
# - Executes the `run_pytest_cuda.sh` script, which performs the following actions:
# - Installs the downloaded wheel artifacts.
# - Runs the CUDA tests with Pytest.
name: CI - Pytest CUDA
on:
  workflow_call:
    # Every input is optional: a caller that omits one gets the default declared
    # here. (Marking an input `required: true` forces the caller to supply it,
    # which would make its `default` unreachable.)
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: false
        default: "linux-x86-n2-16"
      python:
        description: "Which python version to test?"
        type: string
        required: false
        # Quoted so that versions such as "3.10" are not read as the float 3.1.
        default: "3.12"
      cuda:
        description: "Which CUDA version to test?"
        type: string
        required: false
        default: "12.3"
      # Declared as a string ("0"/"1"), not a boolean.
      enable-x64:
        description: "Should x64 mode be enabled?"
        type: string
        required: false
        default: "0"
      # Declared as a string ("0"/"1"), not a boolean.
      install-jax-current-commit:
        description: "Should the 'jax' package be installed from the current commit?"
        type: string
        required: false
        default: "1"
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        type: string
        required: false
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false
jobs:
  # Downloads the prebuilt wheels from GCS, installs the Python test
  # dependencies, and then runs ./ci/run_pytest_cuda.sh inside a CUDA container.
  run-tests:
    defaults:
      run:
        # Explicitly set the shell to bash
        shell: bash
    runs-on: ${{ inputs.runner }}
    # TODO: Update to the generic ML ecosystem test containers when they are ready.
    # The image is selected from the requested CUDA version; only 12.3 and 12.1
    # have an image mapped in this expression.
    container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
                   (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
    name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"

    # Workflow inputs re-exported as JAXCI_* variables for the steps below.
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
      JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
          # Persist the values so that later steps can read them as env vars.
          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> "$GITHUB_ENV"

      - name: Download the wheel artifacts from GCS
        id: download-wheel-artifacts
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the next step so that we can print a more
        # informative error message.
        continue-on-error: true
        env:
          # Passed through the environment rather than interpolated into the
          # script text, so the value is always handled as data by the shell.
          GCS_DOWNLOAD_URI: "${{ inputs.gcs_download_uri }}"
        run: |
          # Run each command as its own statement. With errexit, a failure in a
          # non-final member of an `a && b && c` chain does NOT abort the script,
          # so a failed download could otherwise leave this step reporting success.
          set -eo pipefail
          mkdir -p "$(pwd)/dist"
          # The wildcards are quoted on purpose: gsutil expands them, not the shell.
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
          # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
          if [[ "${JAXCI_INSTALL_JAX_CURRENT_COMMIT}" != 1 ]]; then
            gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*py3*none*any.whl" "$(pwd)/dist/"
          fi

      - name: Skip the test run if the wheel artifacts were not downloaded successfully
        # `outcome` is the step result before continue-on-error is applied.
        if: steps.download-wheel-artifacts.outcome == 'failure'
        run: |
          echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
          echo "built successfully by the artifact build jobs and are available in the GCS bucket."
          echo "Skipping the test run."
          exit 1

      - name: Install Python dependencies
        run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in

      # Halt for testing
      # NOTE(review): this action is referenced by a mutable branch (`main`),
      # unlike the SHA-pinned checkout above - consider pinning to a commit SHA.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}

      - name: Run Pytest CUDA tests
        timeout-minutes: 60
        run: ./ci/run_pytest_cuda.sh