name: ROCm CPU CI

# Python versions exercised by the jobs in this workflow:
# - 3.10 : build matrix (oldest), documentation doctests,
#          documentation render, jax2tf_test
# - 3.11 : lint_and_typecheck
# - 3.12 : FFI example
# - 3.13 : build matrix (newest)

on:
  # Trigger the workflow on push or pull request,
  # but only for the rocm-main branch
  push:
    branches:
      - rocm-main
  pull_request:
    branches:
      - rocm-main

permissions:
  contents: read  # to fetch code
  actions: write  # to cancel previous workflows

# One active run per workflow and branch/PR; a newer run cancels the older one.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  # Runs every pre-commit hook over the whole tree.
  lint_and_typecheck:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python 3.11
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          # Quoted so the version is never read as a float.
          python-version: "3.11"
      - run: python -m pip install pre-commit
      - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57  # v4.2.0
        with:
          path: ~/.cache/pre-commit
          # Cache is invalidated when the interpreter location, the hook
          # configuration, or setup.py changes.
          key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
      - run: pre-commit run --show-diff-on-failure --color=always --all-files

  # Runs the `tests` and `examples` pytest suites on the ROCm runner.
  build:
    # NOTE(review): the display name says "ubuntu-20.04" but the job runs on
    # the ROCM-Ubuntu runner. The name is left untouched because status-check
    # names may be referenced elsewhere; confirm before renaming.
    name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
    runs-on: ROCM-Ubuntu
    timeout-minutes: 60
    strategy:
      matrix:
        # Test the oldest and newest supported Python versions here.
        include:
          - name-prefix: "with 3.10"
            python-version: "3.10"
            enable-x64: 1
            prng-upgrade: 1
            num_generated_cases: 1
          - name-prefix: "with 3.13"
            python-version: "3.13"
            enable-x64: 0
            prng-upgrade: 0
            num_generated_cases: 1
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install uv~=0.5.30
          uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
      - name: Run tests
        env:
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          # Both PRNG settings are driven by the same matrix value.
          JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
          JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
          JAX_ENABLE_CHECKS: "true"
          JAX_SKIP_SLOW_TESTS: "true"
          PY_COLORS: "1"
        run: |
          uv pip install --system -e .
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
          echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          pytest -n 4 --tb=short --maxfail=20 tests examples

  # Runs doctests over the docs sources and over the `jax` package.
  documentation:
    name: Documentation - test code snippets
    runs-on: ROCM-Ubuntu
    timeout-minutes: 10
    strategy:
      matrix:
        python-version: ['3.10']
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install uv~=0.5.30
          uv pip install --system -r docs/requirements.txt
      - name: Test documentation
        env:
          XLA_FLAGS: "--xla_force_host_platform_device_count=8"
          JAX_TRACEBACK_FILTERING: "off"
          JAX_ARRAY: "1"
          PY_COLORS: "1"
        run: |
          pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
          pytest -n auto --tb=short --doctest-modules jax \
            --ignore=jax/config.py \
            --ignore=jax/experimental/jax2tf \
            --ignore=jax/_src/lib/mlir \
            --ignore=jax/_src/lib/triton.py \
            --ignore=jax/_src/lib/mosaic_gpu.py \
            --ignore=jax/interpreters/mlir.py \
            --ignore=jax/experimental/array_serialization \
            --ignore=jax/collect_profile.py \
            --ignore=jax/_src/tpu_custom_call.py \
            --ignore=jax/experimental/mosaic \
            --ignore=jax/experimental/pallas \
            --ignore=jax/_src/pallas \
            --ignore=jax/lib/xla_extension.py

  # Builds the HTML documentation; -W turns Sphinx warnings into errors.
  documentation_render:
    name: Documentation - render documentation
    runs-on: ubuntu-latest
    timeout-minutes: 20
    strategy:
      matrix:
        python-version: ['3.10']
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install uv~=0.5.30
          uv pip install --system -r docs/requirements.txt
      - name: Render documentation
        run: |
          sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html

  # Runs the jax2tf test module against a pre-release TensorFlow build.
  jax2tf_test:
    name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
    runs-on: ${{ matrix.os }}
    timeout-minutes: 30
    strategy:
      matrix:
        # Test the oldest supported Python version here.
        include:
          - python-version: "3.10"
            os: ubuntu-latest
            enable-x64: 0
            num_generated_cases: 10
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install uv~=0.5.30
          uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
          uv pip install --system --pre tensorflow==2.19.0rc0
      - name: Run tests
        env:
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          JAX_ENABLE_CHECKS: "true"
          JAX_SKIP_SLOW_TESTS: "true"
          PY_COLORS: "1"
        run: |
          uv pip install --system -e .
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py

  # Builds the FFI example project and runs its tests twice: once with
  # JAX_PLATFORM_NAME=cpu and once without that variable set.
  ffi:
    name: FFI example
    runs-on: ROCM-Ubuntu
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          # Quoted so the version is never read as a float.
          python-version: "3.12"
      - name: Install JAX
        # uv is installed once, pinned to the same release the other jobs use.
        run: |
          pip install uv~=0.5.30
          uv pip install --system .
      - name: Build and install example project
        run: uv pip install --system ./examples/ffi[test]
        env:
          # We test building using GCC instead of clang. All other JAX builds
          # use clang, but it is useful to make sure that FFI users can compile
          # using a different toolchain. The compiler is set explicitly rather
          # than relying on the runner's default.
          # NOTE(review): this job runs on ROCM-Ubuntu, not 'ubuntu-latest';
          # confirm g++ is available on that runner.
          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++  # -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
      - name: Run CPU tests
        run: python -m pytest examples/ffi/tests
        env:
          JAX_PLATFORM_NAME: cpu
      - name: Run GPU tests
        run: python -m pytest examples/ffi/tests