name: ROCm CPU CI

# Python versions exercised by the jobs in this workflow:
# - 3.10 : build matrix (x64, custom PRNG and partitionable threefry enabled),
#          documentation tests, documentation render and jax2tf tests
# - 3.11 : lint and type checking (pre-commit)
# - 3.12 : FFI example
# - 3.13 : build matrix (x64, custom PRNG and partitionable threefry disabled)

on:
  # Trigger the workflow on push or pull request,
  # but only for the rocm-main branch.
  push:
    branches:
      - rocm-main
  pull_request:
    branches:
      - rocm-main

permissions:
  contents: read  # to fetch code
  actions: write  # to cancel previous workflows

# Runs are grouped by workflow name plus the PR head branch (or the pushed ref);
# a newer run in the same group cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  # Runs every pre-commit hook over all files in the repository.
  lint_and_typecheck:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python 3.11
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          # Quoted so YAML keeps the version a string (unquoted 3.10 would be 3.1).
          python-version: "3.11"
      - run: python -m pip install pre-commit
      # Cache pre-commit hook environments. The key changes with the Python
      # install location exported by setup-python and with the contents of the
      # pre-commit config and setup.py.
      - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57  # v4.2.0
        with:
          path: ~/.cache/pre-commit
          key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
      - run: pre-commit run --show-diff-on-failure --color=always --all-files

  build:
    # NOTE(review): the display name says "ubuntu-20.04" but the job runs on the
    # ROCM-Ubuntu runner. Left unchanged because required status checks may
    # match on this exact name; rename both together if fixing.
    name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
    runs-on: ROCM-Ubuntu
    timeout-minutes: 60
    strategy:
      matrix:
        # Test the oldest and newest supported Python versions here.
        include:
          - name-prefix: "with 3.10"
            python-version: "3.10"
            enable-x64: 1
            prng-upgrade: 1
            num_generated_cases: 1
          - name-prefix: "with 3.13"
            python-version: "3.13"
            enable-x64: 0
            prng-upgrade: 0
            num_generated_cases: 1
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        # The extras spec is quoted so the shell never treats [...] as a glob.
        run: |
          pip install uv~=0.5.30
          uv pip install --system '.[minimum-jaxlib]' -r build/test-requirements.txt

      - name: Run tests
        env:
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          # prng-upgrade toggles both the custom PRNG and partitionable threefry.
          JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
          JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
          JAX_ENABLE_CHECKS: "true"
          JAX_SKIP_SLOW_TESTS: "true"
          PY_COLORS: 1
        run: |
          uv pip install --system -e .
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
          echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          pytest -n 4 --tb=short --maxfail=20 tests examples


  # Executes the code snippets embedded in the docs and in JAX docstrings.
  documentation:
    name: Documentation - test code snippets
    runs-on: ROCM-Ubuntu
    timeout-minutes: 10
    strategy:
      matrix:
        python-version:
          - "3.10"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install uv~=0.5.30
          uv pip install --system -r docs/requirements.txt
      - name: Test documentation
        env:
          PY_COLORS: 1
          JAX_ARRAY: 1
          # Must stay quoted: a bare off would be parsed as boolean false.
          JAX_TRACEBACK_FILTERING: "off"
          XLA_FLAGS: "--xla_force_host_platform_device_count=8"
        run: |
          # Doctests in the Markdown and reStructuredText documentation.
          pytest -n auto --tb=short \
            --doctest-glob='*.md' --doctest-glob='*.rst' \
            docs \
            --doctest-continue-on-failure \
            --ignore=docs/multi_process.md
          # Doctests in module docstrings under jax/, minus the paths ignored below.
          pytest -n auto --tb=short --doctest-modules jax \
            --ignore=jax/config.py \
            --ignore=jax/experimental/jax2tf \
            --ignore=jax/_src/lib/mlir \
            --ignore=jax/_src/lib/triton.py \
            --ignore=jax/_src/lib/mosaic_gpu.py \
            --ignore=jax/interpreters/mlir.py \
            --ignore=jax/experimental/array_serialization \
            --ignore=jax/collect_profile.py \
            --ignore=jax/_src/tpu_custom_call.py \
            --ignore=jax/experimental/mosaic \
            --ignore=jax/experimental/pallas \
            --ignore=jax/_src/pallas \
            --ignore=jax/lib/xla_extension.py


  # Builds the HTML documentation without executing notebooks.
  documentation_render:
    name: Documentation - render documentation
    runs-on: ubuntu-latest
    timeout-minutes: 20
    strategy:
      matrix:
        python-version:
          - "3.10"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install uv~=0.5.30
          uv pip install --system -r docs/requirements.txt
      - name: Render documentation
        # -W turns warnings into errors; --keep-going reports all of them before
        # failing. nb_execution_mode=off skips running notebooks.
        run: |
          sphinx-build -j auto --color -W --keep-going \
            -b html -D nb_execution_mode=off \
            docs docs/build/html

  jax2tf_test:
    name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
    runs-on: ${{ matrix.os }}
    timeout-minutes: 30
    strategy:
      matrix:
        # Test the oldest supported Python version here.
        include:
          - python-version: "3.10"
            os: ubuntu-latest
            enable-x64: 0
            num_generated_cases: 10
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        # The extras spec is quoted so the shell never treats [...] as a glob.
        # TensorFlow is pinned to a release candidate.
        run: |
          pip install uv~=0.5.30
          uv pip install --system '.[minimum-jaxlib]' -r build/test-requirements.txt
          uv pip install --system --pre tensorflow==2.19.0rc0

      - name: Run tests
        env:
          JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
          JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
          JAX_ENABLE_CHECKS: "true"
          JAX_SKIP_SLOW_TESTS: "true"
          PY_COLORS: 1
        run: |
          uv pip install --system -e .
          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
          echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
          pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py

  # Builds the example FFI project under examples/ffi with GCC and runs its tests.
  ffi:
    name: FFI example
    runs-on: ROCM-Ubuntu
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set up Python
        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38  # v5.4.0
        with:
          python-version: "3.12"
      - name: Install JAX
        # Same uv version pin as the other jobs in this workflow.
        run: |
          pip install uv~=0.5.30
          uv pip install --system .
      - name: Build and install example project
        # The extras spec is quoted so the shell never treats [...] as a glob.
        run: uv pip install --system './examples/ffi[test]'
        env:
          # We test building using GCC instead of clang. All other JAX builds use
          # clang, but it is useful to make sure that FFI users can compile using
          # a different toolchain. The compiler is set explicitly so the build
          # does not depend on whatever default toolchain the runner provides.
          # CUDA support for the example (-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON) is
          # not enabled in this workflow.
          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++
      - name: Run CPU tests
        run: python -m pytest examples/ffi/tests
        env:
          JAX_PLATFORM_NAME: cpu
      # NOTE(review): this job installs only `.` (no ROCm/GPU plugin is installed
      # here), so this step presumably runs on the default backend, which may be
      # the CPU, duplicating the step above. Confirm it actually exercises a GPU.
      - name: Run GPU tests
        run: python -m pytest examples/ffi/tests