mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html), Python 3.6 support has been due for removal since June 23, 2020.
151 lines
4.8 KiB
YAML
151 lines
4.8 KiB
YAML
# Continuous-integration workflow: lint/type-check, a matrix of test builds,
# and a documentation test job.
name: CI

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  # Runs the repository's pre-commit hooks under Python 3.8.
  lint_and_typecheck:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      # Skipped on main itself. Branch refs have the form
      # refs/heads/<name>, so the comparison must use "heads".
      - name: Cancel previous
        uses: styfle/cancel-workflow-action@0.8.0
        with:
          access_token: ${{ github.token }}
        if: ${{ github.ref != 'refs/heads/main' }}
      - uses: actions/checkout@v2
      - name: Set up Python 3.8
        uses: actions/setup-python@v2
        with:
          # Quoted so the version is a string, not a float.
          python-version: "3.8"
      - uses: pre-commit/action@v2.0.0

  # Installs test requirements (plus optional per-entry overrides) and runs
  # pytest over `tests` and `examples` for each matrix entry.
  build:
    name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      matrix:
        # Per-entry keys:
        #   package-overrides   - extra pip requirement(s) installed after
        #                         the test requirements; "none" skips this.
        #   num_generated_cases - exported as JAX_NUM_GENERATED_CASES.
        #   enable-x64          - exported as JAX_ENABLE_X64.
        #   use-latest-jaxlib   - when true, also runs `pip install -e .[cpu]`.
        # Python versions are quoted so they stay strings (an unquoted 3.10
        # would be read as the float 3.1).
        include:
          - name-prefix: "with many tests"
            python-version: "3.8"
            os: ubuntu-latest
            enable-x64: 0
            package-overrides: "none"
            num_generated_cases: 25
            use-latest-jaxlib: false
          - name-prefix: "with numpy-dispatch"
            python-version: "3.9"
            os: ubuntu-latest
            enable-x64: 1
            # Test experimental NumPy dispatch
            package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
            num_generated_cases: 10
            use-latest-jaxlib: false
          - name-prefix: "with internal numpy/scipy"
            python-version: "3.7"
            os: ubuntu-latest
            enable-x64: 1
            # Test with numpy version that matches Google's internal version
            package-overrides: "numpy==1.19.5 scipy==1.2.1"
            num_generated_cases: 10
            use-latest-jaxlib: false
          - name-prefix: "with 3.7"
            python-version: "3.7"
            os: ubuntu-latest
            enable-x64: 0
            # Test with the oldest legal NumPy version.
            package-overrides: "numpy==1.17.5 scipy==1.2.1"
            num_generated_cases: 8
            # Test against latest jaxlib
            use-latest-jaxlib: true
    steps:
      # Skipped on main itself (branch refs are refs/heads/<name>).
      - name: Cancel previous
        uses: styfle/cancel-workflow-action@0.8.0
        with:
          access_token: ${{ github.token }}
        if: ${{ github.ref != 'refs/heads/main' }}
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      # Publishes pip's cache directory as the step output `dir`.
      # NOTE(review): `::set-output` is deprecated by GitHub in favour of
      # writing to $GITHUB_OUTPUT; consider migrating.
      - name: Get pip cache dir
        id: pip-cache
        run: |
          python -m pip install --upgrade pip wheel
          echo "::set-output name=dir::$(pip cache dir)"
      # Cache key changes whenever any setup.py or requirements file changes;
      # restore-keys allows falling back to the newest cache for this OS.
      - name: pip cache
        uses: actions/cache@v2
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-
      # The override list is deliberately left unquoted in the pip command so
      # that a space-separated value installs several packages.
      - name: Install dependencies
        run: |
          pip install -r build/test-requirements.txt
          if [ "${{ matrix.package-overrides }}" != "none" ]; then
            pip install ${{ matrix.package-overrides }}
          fi
          if [ "${{ matrix.use-latest-jaxlib }}" == "true" ]; then
            pip install -e .[cpu]
          fi

      - name: Run tests
        env:
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
        run: |
          pip install -e .
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          pytest -n auto --tb=short tests examples


  # Installs docs/requirements.txt, then runs pytest on `docs` and the
  # doctests in the `jax` package (excluding jax/experimental/jax2tf).
  documentation:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    strategy:
      matrix:
        python-version: ["3.7"]
    steps:
      # Skipped on main itself (branch refs are refs/heads/<name>).
      - name: Cancel previous
        uses: styfle/cancel-workflow-action@0.8.0
        with:
          access_token: ${{ github.token }}
        if: ${{ github.ref != 'refs/heads/main' }}
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      # Publishes pip's cache directory as the step output `dir`.
      # NOTE(review): `::set-output` is deprecated by GitHub in favour of
      # writing to $GITHUB_OUTPUT; consider migrating.
      - name: Get pip cache dir
        id: pip-cache
        run: |
          python -m pip install --upgrade pip wheel
          echo "::set-output name=dir::$(pip cache dir)"
      - name: pip cache
        uses: actions/cache@v2
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-
      - name: Install dependencies
        run: |
          pip install -r docs/requirements.txt
      - name: Test documentation
        env:
          # Passed through to XLA: force 8 host-platform devices.
          XLA_FLAGS: "--xla_force_host_platform_device_count=8"
        run: |
          pytest -n 1 --tb=short docs
          pytest -n 1 --tb=short --doctest-modules --ignore=jax/experimental/jax2tf jax