mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00

This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py. We disable the test in GitHub CI, because it is very large, pending some changes to ensure it parallelizes well. The test is still running in internal CI. This is matching the current behavior, since jax2tf tests are only run internally. PiperOrigin-RevId: 583603863
200 lines
7.0 KiB
YAML
200 lines
7.0 KiB
YAML
name: CI

# Python version coverage in this workflow:
#   - 3.9  : documentation build
#   - 3.9  : build matrix entry with NumPy dispatch
#   - 3.10 : build matrix entry
#   - 3.11 : build matrix entry

# Run on pushes to main and on pull requests that target main.
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

permissions:
  contents: read  # to fetch code
  actions: write  # to cancel previous workflows
|
|
|
|
jobs:
|
|
# Runs the repository's pre-commit hooks on Python 3.11.
# NOTE(review): the job name suggests the hooks cover linting and type
# checking; the hook list itself is not visible in this file.
lint_and_typecheck:
  runs-on: ubuntu-latest
  timeout-minutes: 5
  steps:
    # Skipped when the run is on the main branch.
    - name: Cancel previous
      uses: styfle/cancel-workflow-action@0.12.0
      with:
        access_token: ${{ github.token }}
      if: ${{github.ref != 'refs/heads/main'}}
    - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
    - name: Set up Python 3.11
      uses: actions/setup-python@v4
      with:
        # Quoted so YAML keeps the version as a string rather than a float,
        # matching the quoted versions in the build matrix.
        python-version: "3.11"
    - uses: pre-commit/action@v3.0.0
|
|
|
|
# Runs pytest over `tests` and `examples` once per matrix entry.
#
# Matrix keys and where they are consumed in the steps below:
#   - enable-x64          -> JAX_ENABLE_X64
#   - prng-upgrade        -> JAX_ENABLE_CUSTOM_PRNG and JAX_THREEFRY_PARTITIONABLE
#   - num_generated_cases -> JAX_NUM_GENERATED_CASES
#   - package-overrides   -> extra `pip install` argument; "none" skips it
#   - use-latest-jaxlib   -> "true" installs .[cpu], otherwise .[minimum-jaxlib]
build:
  name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
  runs-on: ${{ matrix.os }}
  timeout-minutes: 60
  strategy:
    matrix:
      include:
        - name-prefix: "with numpy-dispatch"
          python-version: "3.9"
          os: ubuntu-20.04-16core
          enable-x64: 1
          prng-upgrade: 1
          # Test experimental NumPy dispatch
          package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
          num_generated_cases: 1
          use-latest-jaxlib: false
        - name-prefix: "with 3.10"
          python-version: "3.10"
          os: ubuntu-20.04-16core
          enable-x64: 0
          prng-upgrade: 0
          package-overrides: "none"
          num_generated_cases: 1
          use-latest-jaxlib: false
        - name-prefix: "with 3.11"
          python-version: "3.11"
          os: ubuntu-20.04-16core
          enable-x64: 0
          prng-upgrade: 0
          package-overrides: "none"
          num_generated_cases: 1
          use-latest-jaxlib: false
  steps:
    # Skipped when the run is on the main branch.
    - name: Cancel previous
      uses: styfle/cancel-workflow-action@0.12.0
      with:
        access_token: ${{ github.token }}
      if: ${{github.ref != 'refs/heads/main'}}
    - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    # Publishes pip's cache directory as the step output `dir`.
    - name: Get pip cache dir
      id: pip-cache
      run: |
        python -m pip install --upgrade pip wheel
        echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
    - name: pip cache
      uses: actions/cache@v3
      with:
        path: ${{ steps.pip-cache.outputs.dir }}
        key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
    - name: Install dependencies
      # The extras are quoted so the shell cannot treat `[...]` as a glob.
      run: |
        pip install -r build/test-requirements.txt
        if [ "${{ matrix.package-overrides }}" != "none" ]; then
          pip install ${{ matrix.package-overrides }}
        fi
        if [ "${{ matrix.use-latest-jaxlib }}" == "true" ]; then
          pip install ".[cpu]"
        else
          pip install ".[minimum-jaxlib]"
        fi

    - name: Run tests
      env:
        JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
        JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
        JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
        JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
        JAX_ENABLE_CHECKS: true
        JAX_SKIP_SLOW_TESTS: true
        PY_COLORS: 1
      # Echo the effective settings into the log before running pytest.
      run: |
        pip install -e .
        echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
        echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
        echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
        echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
        echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
        echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
        pytest -n auto --tb=short --maxfail=20 tests examples
|
|
|
|
|
|
# Runs pytest over `docs`, then the doctests found under `jax` (minus the
# paths excluded with --ignore).
documentation:
  name: Documentation - test code snippets
  runs-on: ubuntu-latest
  timeout-minutes: 10
  strategy:
    matrix:
      # Quoted so YAML keeps the version as a string rather than a float.
      python-version: ["3.9"]
  steps:
    # Skipped when the run is on the main branch.
    - name: Cancel previous
      uses: styfle/cancel-workflow-action@0.12.0
      with:
        access_token: ${{ github.token }}
      if: ${{github.ref != 'refs/heads/main'}}
    - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    # Publishes pip's cache directory as the step output `dir`.
    - name: Get pip cache dir
      id: pip-cache
      run: |
        python -m pip install --upgrade pip wheel
        echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
    - name: pip cache
      uses: actions/cache@v3
      with:
        path: ${{ steps.pip-cache.outputs.dir }}
        key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
    - name: Install dependencies
      run: |
        pip install -r docs/requirements.txt
    - name: Test documentation
      env:
        XLA_FLAGS: "--xla_force_host_platform_device_count=8"
        JAX_ARRAY: 1
        PY_COLORS: 1
      run: |
        pytest -n auto --tb=short docs
        pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
|
|
|
|
|
|
# Builds the HTML documentation with sphinx-build from `docs` into
# `docs/build/html`, with nb_execution_mode set to off.
documentation_render:
  name: Documentation - render documentation
  runs-on: ubuntu-latest
  timeout-minutes: 10
  strategy:
    matrix:
      # Quoted so YAML keeps the version as a string rather than a float.
      python-version: ["3.9"]
  steps:
    # Skipped when the run is on the main branch.
    - name: Cancel previous
      uses: styfle/cancel-workflow-action@0.12.0
      with:
        access_token: ${{ github.token }}
      if: ${{github.ref != 'refs/heads/main'}}
    - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    # Publishes pip's cache directory as the step output `dir`.
    - name: Get pip cache dir
      id: pip-cache
      run: |
        python -m pip install --upgrade pip wheel
        echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
    - name: pip cache
      uses: actions/cache@v3
      with:
        path: ${{ steps.pip-cache.outputs.dir }}
        key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
    - name: Install dependencies
      run: |
        pip install -r docs/requirements.txt
    - name: Render documentation
      run: |
        sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html -j auto
|