# Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
# 212 lines, 7.8 KiB, YAML
name: ROCm CPU CI

# Python versions exercised by the jobs in this workflow:
#   - 3.10 : `build` matrix entry with x64 and the PRNG upgrade enabled,
#            documentation doctests, documentation rendering, jax2tf tests
#   - 3.11 : lint and type checking (pre-commit)
#   - 3.12 : FFI example
#   - 3.13 : `build` matrix entry with x64 and the PRNG upgrade disabled

on:
  # Trigger the workflow on pushes and pull requests,
  # but only for the rocm-main branch.
  push:
    branches:
      - rocm-main
  pull_request:
    branches:
      - rocm-main

permissions:
  contents: read # to fetch code
  actions: write # to cancel previous workflows

concurrency:
  # Runs are grouped per workflow and per PR head ref (falling back to the
  # pushed ref); an in-progress run in the same group is cancelled.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
lint_and_typecheck:
  # Runs the repository's pre-commit hooks against every file.
  runs-on: ubuntu-latest
  timeout-minutes: 5
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python 3.11
      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      with:
        # Quoted so YAML keeps the version as a string rather than a float.
        python-version: "3.11"
    - run: python -m pip install pre-commit
    - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
      with:
        # Cache pre-commit's hook environments. The key changes with the
        # interpreter location and with the hashed config files.
        path: ~/.cache/pre-commit
        key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
    - run: pre-commit run --show-diff-on-failure --color=always --all-files

build:
  # Installs JAX with the minimum-jaxlib extra and runs the `tests` and
  # `examples` suites once per matrix entry.
  #
  # NOTE(review): the display name still says "ubuntu-20.04" although the job
  # runs on the ROCM-Ubuntu runner label. Left unchanged in case required
  # status checks match on this name; confirm before renaming.
  name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
  runs-on: ROCM-Ubuntu
  timeout-minutes: 60
  strategy:
    matrix:
      # Test the oldest and newest supported Python versions here.
      include:
        - name-prefix: "with 3.10"
          python-version: "3.10"
          enable-x64: 1
          prng-upgrade: 1
          num_generated_cases: 1
        - name-prefix: "with 3.13"
          python-version: "3.13"
          enable-x64: 0
          prng-upgrade: 0
          num_generated_cases: 1
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install uv~=0.5.30
        uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt

    - name: Run tests
      env:
        JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
        JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
        # The single `prng-upgrade` matrix value drives both PRNG settings.
        JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
        JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
        JAX_ENABLE_CHECKS: true
        JAX_SKIP_SLOW_TESTS: true
        PY_COLORS: 1
      run: |
        uv pip install --system -e .
        echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
        echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
        echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
        echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
        echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
        echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
        pytest -n 4 --tb=short --maxfail=20 tests examples

documentation:
  # Runs the doctests: first the snippets in the Markdown/reST sources under
  # `docs`, then the docstrings of the `jax` package.
  name: Documentation - test code snippets
  runs-on: ROCM-Ubuntu
  timeout-minutes: 10
  strategy:
    matrix:
      python-version: ['3.10']
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install uv~=0.5.30
        uv pip install --system -r docs/requirements.txt
    - name: Test documentation
      env:
        XLA_FLAGS: "--xla_force_host_platform_device_count=8"
        JAX_TRACEBACK_FILTERING: "off"
        JAX_ARRAY: 1
        PY_COLORS: 1
      run: |
        pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
        pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py

documentation_render:
  # Builds the HTML documentation with Sphinx. `-W --keep-going` turns
  # warnings into errors while still reporting all of them, and notebook
  # execution is switched off via `nb_execution_mode=off`.
  name: Documentation - render documentation
  runs-on: ubuntu-latest
  timeout-minutes: 20
  strategy:
    matrix:
      python-version: ['3.10']
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install uv~=0.5.30
        uv pip install --system -r docs/requirements.txt
    - name: Render documentation
      run: |
        sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html

jax2tf_test:
  # Runs the jax2tf test module against a pinned TensorFlow pre-release.
  name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
  runs-on: ${{ matrix.os }}
  timeout-minutes: 30
  strategy:
    matrix:
      # Test the oldest supported Python version here.
      include:
        - python-version: "3.10"
          os: ubuntu-latest
          enable-x64: 0
          num_generated_cases: 10
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install uv~=0.5.30
        uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
        uv pip install --system --pre tensorflow==2.19.0rc0

    - name: Run tests
      env:
        JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
        JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
        JAX_ENABLE_CHECKS: true
        JAX_SKIP_SLOW_TESTS: true
        PY_COLORS: 1
      run: |
        uv pip install --system -e .
        echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
        echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
        echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
        echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
        pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py

ffi:
  # Builds the example project under `examples/ffi` against the checked-out
  # JAX and runs its test suite.
  name: FFI example
  runs-on: ROCM-Ubuntu
  timeout-minutes: 30
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
    - name: Set up Python
      uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
      with:
        # Quoted so YAML keeps the version as a string rather than a float.
        python-version: "3.12"
    - name: Install JAX
      run: |
        pip install uv~=0.5.30
        uv pip install --system .
    - name: Build and install example project
      run: uv pip install --system ./examples/ffi[test]
      env:
        # We test building using GCC instead of clang. All other JAX builds use
        # clang, but it is useful to make sure that FFI users can compile using
        # a different toolchain, so the compiler is set explicitly rather than
        # relying on whatever the runner defaults to.
        CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ # -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
    - name: Run CPU tests
      run: python -m pytest examples/ffi/tests
      env:
        JAX_PLATFORM_NAME: cpu
    # NOTE(review): same command as the CPU step but without the
    # JAX_PLATFORM_NAME override, and the CUDA CMake flag above is commented
    # out. Confirm this step actually exercises a GPU backend on this runner.
    - name: Run GPU tests
      run: python -m pytest examples/ffi/tests