mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
CI: add jax2tf_test action
This commit is contained in:
parent
00528b9858
commit
d03572eac5
54
.github/workflows/ci-build.yaml
vendored
54
.github/workflows/ci-build.yaml
vendored
@ -178,3 +178,57 @@ jobs:
|
||||
- name: Render documentation
|
||||
run: |
|
||||
sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html
|
||||
|
||||
|
||||
jax2tf_test:
|
||||
name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
matrix:
|
||||
# Test the oldest supported Python version here.
|
||||
include:
|
||||
- python-version: "3.10"
|
||||
os: ubuntu-latest
|
||||
enable-x64: 0
|
||||
num_generated_cases: 10
|
||||
steps:
|
||||
- name: Cancel previous
|
||||
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
|
||||
with:
|
||||
access_token: ${{ github.token }}
|
||||
if: ${{github.ref != 'refs/heads/main'}}
|
||||
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # ratchet:actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Get pip cache dir
|
||||
id: pip-cache
|
||||
run: |
|
||||
python -m pip install --upgrade pip wheel
|
||||
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
|
||||
- name: pip cache
|
||||
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
|
||||
with:
|
||||
path: ${{ steps.pip-cache.outputs.dir }}
|
||||
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install .[minimum-jaxlib] tensorflow -r build/test-requirements.txt
|
||||
|
||||
- name: Run tests
|
||||
env:
|
||||
JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
|
||||
JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
|
||||
JAX_ENABLE_CHECKS: true
|
||||
JAX_SKIP_SLOW_TESTS: true
|
||||
PY_COLORS: 1
|
||||
run: |
|
||||
pip install -e .
|
||||
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
|
||||
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
|
||||
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
|
||||
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
|
||||
pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py
|
||||
|
Loading…
x
Reference in New Issue
Block a user