# Mirror of https://github.com/ROCm/jax.git
# Synced 2025-04-16 11:56:07 +00:00
# Original file as mirrored: 54 lines, 1.6 KiB, YAML
# JAX-Metal plugin CI

name: Jax-Metal CI

on:
  schedule:
    # Daily at 12:00 UTC.
    - cron: '0 12 * * *'
  # Allows triggering the workflow run manually.
  workflow_dispatch:
  # Automatically trigger on pull requests against main whose changed files
  # match the path filter below.
  pull_request:
    branches:
      - main
    paths:
      - '**workflows/metal_plugin_ci.yml'

# Runs are grouped per workflow and per pull-request head branch (falling
# back to the ref when there is no head branch); a newer run in the same
# group cancels the one still in progress.
concurrency:
  cancel-in-progress: true
  group: '${{ github.workflow }}-${{ github.head_ref || github.ref }}'

jobs:
  # Builds a fresh virtualenv, installs jax from the checked-out tree together
  # with the jax-metal plugin, and runs tests/lax_metal_test.py. The job runs
  # once per entry of the jaxlib-version matrix.
  jax-metal-plugin-test:

    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        jaxlib-version: ["pypi_latest", "nightly"]
    name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
    runs-on: [self-hosted, macOS, ARM64]

    # The steps below only read the repository, so limit the GITHUB_TOKEN to
    # read access instead of inheriting the repository default.
    permissions:
      contents: read

    steps:
      - name: Get repo
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          path: jax
          # No later step performs git operations, so do not leave the token
          # in the local git config on the self-hosted runner.
          persist-credentials: false
      - name: Setup build and test environment
        run: |
          # Recreate the virtualenv from scratch on every run. Paths are
          # quoted so a workspace directory containing spaces still works.
          rm -rf "${GITHUB_WORKSPACE}/jax-metal-venv"
          python3 -m venv "${GITHUB_WORKSPACE}/jax-metal-venv"
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          pip install uv~=0.5.30
          uv pip install -U pip numpy wheel absl-py pytest
          # Only the "nightly" matrix entry pre-installs a pre-release jaxlib;
          # for any other entry no explicit jaxlib install happens here.
          if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
            uv pip install --pre jaxlib \
              -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
          fi
          cd jax
          # Install jax from the checkout plus the jax-metal plugin.
          uv pip install . jax-metal
      - name: Run test
        run: |
          # Each run step is a new shell, so the venv must be re-activated.
          source "${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate"
          export ENABLE_PJRT_COMPATIBILITY=1
          cd jax
          pytest tests/lax_metal_test.py