name: CI - with Numpy/Scipy nightly wheels (nightly)
# This configures a GitHub Actions workflow that runs the JAX test suite against nightly
# development builds of numpy and scipy, in order to catch issues with new package versions
# prior to their release.
# Unlike our other CI, this is one that we expect to fail frequently, and so we don't run it
# against every commit and PR in the repository. Rather, we run it on a schedule, and failures
# lead to an issue being created or updated.
# Portions of this are adapted from
# https://github.com/pydata/xarray/blob/main/.github/workflows/upstream-dev-ci.yaml

# Only one run per workflow/ref pair is kept alive; a newer run cancels the one in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

on:
  schedule:
    - cron: "0 12 * * *"  # Daily at 12:00 UTC
  workflow_dispatch:  # allows triggering the workflow run manually
  pull_request:  # Automatically trigger on pull requests affecting this file
    branches:
      - main
    paths:
      - '**workflows/upstream-nightly.yml'

jobs:
  upstream-dev:
    runs-on: ROCM-Ubuntu
    permissions:
      contents: read
      issues: write  # for failed-build-issue
    strategy:
      # Let every matrix entry run to completion even if another entry fails.
      fail-fast: false
      matrix:
        python-version: ["3.13"]
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install JAX test requirements
        run: |
          pip install uv~=0.5.30
          uv pip install --system .[ci] -r build/test-requirements.txt

      # Replace the released numpy/scipy with the nightly wheels. --no-deps leaves the
      # rest of the already-installed environment untouched; the version prints record
      # in the log exactly which builds were tested.
      - name: Install numpy & scipy development versions
        run: |
          uv pip install \
            --system \
            -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
            --no-deps \
            --pre \
            --upgrade \
            numpy \
            scipy
          python -c "import numpy; print(f'{numpy.__version__=}')"
          python -c "import scipy; print(f'{scipy.__version__=}')"

      - name: Run tests
        if: success()
        id: status
        # Values are quoted so they are unambiguously strings to any YAML parser;
        # the environment variables seen by the step are the same text as before.
        env:
          JAX_NUM_GENERATED_CASES: "1"
          JAX_ENABLE_X64: "true"
          JAX_ENABLE_CHECKS: "true"
          JAX_SKIP_SLOW_TESTS: "true"
          PY_COLORS: "1"
        run: |
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          pytest -n 2 --tb=short --maxfail=20 tests examples

      # Only runs when an earlier step failed and the run was not triggered by a pull
      # request, so PRs that edit this workflow never open or update an issue.
      - name: Notify failed build
        uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b  # v1.2.0
        if: failure() && github.event.pull_request == null
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}