rocm_jax/.github/workflows/metal_plugin_ci.yml
Nitin Srinivasan 4b4f2f9cb9 Use uv to install Python packages
PiperOrigin-RevId: 730499307
2025-02-24 10:13:39 -08:00

54 lines
1.6 KiB
YAML

# Continuous-integration workflow for the JAX-Metal plugin.
name: Jax-Metal CI
on:
  # Pull requests targeting main that modify this workflow file.
  pull_request:
    branches: [main]
    paths: ["**workflows/metal_plugin_ci.yml"]
  # Scheduled run, every day at 12:00 UTC.
  schedule:
    - cron: '0 12 * * *'
  # Manual runs started from the Actions UI or the API.
  workflow_dispatch:
# Runs are grouped by workflow name plus PR head ref (falling back to the
# triggering ref); starting a new run cancels the one already in progress
# for the same group.
concurrency:
  cancel-in-progress: true
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
jobs:
  # Creates a fresh virtualenv, installs JAX from this checkout together with
  # the jax-metal plugin, and runs tests/lax_metal_test.py. The job is
  # instantiated once per entry of the jaxlib-version matrix.
  jax-metal-plugin-test:
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        jaxlib-version: ["pypi_latest", "nightly"]
    name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
    runs-on: [self-hosted, macOS, ARM64]
    # Least privilege: the steps below only read the repository contents.
    permissions:
      contents: read
    steps:
      - name: Get repo
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          # Check out into ./jax so the later steps can `cd jax`.
          path: jax
          # No later step runs git, so do not keep the token in the
          # checkout's git config on the self-hosted machine.
          persist-credentials: false
      - name: Setup build and test environment
        run: |
          # Start from a clean virtualenv; a previous run may have left one
          # behind in the workspace. Paths are quoted so that a workspace
          # directory containing spaces does not get word-split.
          rm -rf "${GITHUB_WORKSPACE}/jax-metal-venv"
          python3 -m venv "${GITHUB_WORKSPACE}/jax-metal-venv"
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          pip install uv~=0.5.30
          uv pip install -U pip numpy wheel absl-py pytest
          # Only the "nightly" matrix entry installs a pre-release jaxlib from
          # the nightly release index; "pypi_latest" performs no explicit
          # jaxlib install here.
          if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
            uv pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
          fi
          # Install JAX from the checked-out sources plus the jax-metal plugin.
          cd jax
          uv pip install . jax-metal
      - name: Run test
        run: |
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          # Exported for the test process; presumably consumed by the
          # jax-metal plugin -- confirm against its documentation.
          export ENABLE_PJRT_COMPATIBILITY=1
          cd jax
          pytest tests/lax_metal_test.py