mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
122 lines
5.0 KiB
YAML
122 lines
5.0 KiB
YAML
# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU/CUDA tests.
#
# It orchestrates the following:
# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
#    upload it to a GCS bucket.
# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow to download the jaxlib wheel that was built
#    in the previous step and run the CPU tests.
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
#    upload them to a GCS bucket.
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts
#    that were built in the previous steps and run the CUDA tests.
# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow to download the jaxlib and
#    CUDA artifacts that were built in the previous steps and run the CUDA
#    tests using Bazel.

name: CI - Wheel Tests (Continuous)

# Runs only on a timer; there are no push / pull_request triggers in this workflow.
on:
  schedule:
    - cron: "0 */2 * * *"  # Run once every 2 hours

concurrency:
  # One group per workflow and ref. `github.head_ref` is only set for pull request
  # events, so the expression falls back to `github.ref` otherwise.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Do not cancel in-progress runs on main or on release branches.
  # `github.ref` is a fully qualified ref (e.g. `refs/heads/main`), so it has to be
  # compared against the full ref; comparing against the bare name 'main' is always
  # "not equal" and would let every new run on main cancel the previous one.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}

jobs:
  # Builds the jaxlib artifact once per runner platform through the reusable
  # `build_artifacts.yml` workflow and uploads the result to GCS.
  build-jaxlib-artifact:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # let the remaining matrix entries finish when one fails
      matrix:
        # Runner OS and Python values must match the matrix strategy of the
        # run-pytest-cpu job (the machine sizes used there differ).
        runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
        artifact: ["jaxlib"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      # NOTE(review): presumably makes the build use XLA at head — confirm in build_artifacts.yml
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Builds the CUDA plugin and PJRT artifacts through the same reusable workflow and
  # uploads them to the same GCS location as the jaxlib artifact.
  build-cuda-artifacts:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # let the remaining matrix entries finish when one fails
      matrix:
        # Python values must match the matrix strategy of the CUDA test jobs below.
        runner: ["linux-x86-n2-16"]
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      # NOTE(review): presumably makes the build use XLA at head — confirm in build_artifacts.yml
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Runs the CPU pytest suite against the jaxlib artifact built above, with and
  # without x64 enabled.
  run-pytest-cpu:
    needs: build-jaxlib-artifact
    uses: ./.github/workflows/pytest_cpu.yml
    strategy:
      fail-fast: false  # let the remaining matrix entries finish when one fails
      matrix:
        # Runner OS and Python values must match the matrix strategy of the
        # build-jaxlib-artifact job above.
        runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
        python: ["3.10"]
        enable-x64: [1, 0]
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the CUDA pytest suite once both the jaxlib and the CUDA artifacts are built.
  run-pytest-cuda:
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/pytest_cuda.yml
    strategy:
      fail-fast: false  # let the remaining matrix entries finish when one fails
      matrix:
        # Python values must match the matrix strategy of the artifact build jobs above.
        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
        python: ["3.10"]
        cuda: ["12.3", "12.1"]
        enable-x64: [1, 0]
        exclude:
          # Run only a single configuration on H100 to save resources: the two
          # entries below leave just CUDA 12.3 with enable-x64=1 on that runner.
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            cuda: "12.1"
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            enable-x64: 0
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      cuda: ${{ matrix.cuda }}
      enable-x64: ${{ matrix.enable-x64 }}
      # Both artifact build jobs upload to the same GCS URI, so the jaxlib job's
      # output is enough to locate the CUDA artifacts as well.
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the CUDA tests through Bazel once both the jaxlib and the CUDA artifacts are built.
  run-bazel-test-cuda:
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/bazel_cuda_non_rbe.yml
    strategy:
      fail-fast: false  # let the remaining matrix entries finish when one fails
      matrix:
        # Python values must match the matrix strategy of the artifact build jobs above.
        runner: ["linux-x86-g2-48-l4-4gpu"]
        python: ["3.10"]
        enable-x64: [1, 0]
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      # Both artifact build jobs upload to the same GCS URI, so the jaxlib job's
      # output is enough to locate the CUDA artifacts as well.
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}