# ROCm GPU CI
#
# Builds JAX for ROCm inside Docker on a runner labeled `mi-250`, uploads the
# resulting wheels as a workflow artifact, and runs `tests/core_test.py`
# against the freshly built test image.
name: ROCm GPU CI

on:
  # Trigger the workflow on push or pull request,
  # but only for the rocm-main and rocm-jaxlib-v* branches
  push:
    branches:
      - rocm-main
      - 'rocm-jaxlib-v*'
  pull_request:
    branches:
      - rocm-main
      - 'rocm-jaxlib-v*'

# Keep a single in-flight run per workflow and PR head branch (or ref for
# pushes); a newer run cancels the one already in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  build-jax-in-docker:
    # strategy and matrix come here
    runs-on: mi-250
    env:
      BASE_IMAGE: "ubuntu:22.04"
      # Image tag and work directory are made unique per run/attempt so that
      # retries and concurrent runs on the same machine do not collide.
      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
      # Quoted so the versions stay strings ("3.10" must not become 3.1).
      PYTHON_VERSION: "3.10"
      ROCM_VERSION: "6.2.4"
      WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
    steps:
      - name: Clean up old runs
        run: |
          ls
          # Make sure that we own all of the files so that we have permissions to delete them
          docker run -v "./:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
          # Remove any old work directories from this machine
          rm -rf workdir_*
          ls

      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi

      # Check the repository out into the per-run work directory rather than
      # the workspace root.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          path: ${{ env.WORKSPACE_DIR }}

      - name: Build JAX
        run: |
          pushd $WORKSPACE_DIR
          python3 build/rocm/ci_build \
            --rocm-version $ROCM_VERSION \
            --base-docker $BASE_IMAGE \
            --python-versions $PYTHON_VERSION \
            --compiler=clang \
            dist_docker \
            --image-tag $TEST_IMAGE

      # NOTE(review): this path is relative to the runner workspace root, but
      # the checkout and build happen under $WORKSPACE_DIR -- confirm the
      # wheels really land in ./dist and not under the work directory.
      # NOTE(review): checkout is pinned to a commit SHA while this action
      # uses a floating tag; consider pinning it the same way.
      - name: Archive jax wheels
        uses: actions/upload-artifact@v4
        with:
          name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
          path: ./dist/*.whl

      - name: Run tests
        env:
          # Presumably read by build/rocm/ci_build or the test image to select
          # the GPU count and target architecture -- TODO confirm.
          GPU_COUNT: "8"
          GFX: "gfx90a"
        run: |
          cd $WORKSPACE_DIR
          python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"