rocm_jax/.github/workflows/wheel_tests_continuous.yml
2025-03-17 20:19:20 +00:00

192 lines
8.9 KiB
YAML

# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU/CUDA/TPU tests.
#
# It orchestrates the following:
# 1. build-jax-artifact: Calls the `build_artifacts.yml` workflow to build the jax artifact and
#    uploads it to a GCS bucket.
# 2. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
#    uploads it to a GCS bucket.
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
#    uploads them to a GCS bucket.
# 4. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel
#    that was built in the previous steps and runs CPU tests.
# 5. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
#    artifacts that were built in the previous steps and runs the CUDA tests.
# 6. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib
#    and CUDA artifacts that were built in the previous steps and runs the
#    CUDA tests using Bazel.
# 7. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow with the GCS location of the artifacts
#    that were built in the previous steps and runs the TPU tests.
# Workflow name. It is also interpolated into the GCS upload URIs of the build
# jobs via `github.workflow`, so renaming it changes where artifacts are stored.
name: CI - Wheel Tests (Continuous)
# Triggers: a recurring schedule plus manual dispatch.
on:
  schedule:
    - cron: "0 */2 * * *"  # Run once every 2 hours
  workflow_dispatch:  # allows triggering the workflow run manually
# Runs of this workflow on the same ref share a single concurrency group.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Cancel a superseded in-progress run unless it is on a release branch or on
  # main. `github.ref` is a fully qualified ref (e.g. `refs/heads/main`), so the
  # comparison has to use that form; comparing against the bare string 'main'
  # never matches and would leave runs on main cancellable.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
jobs:
  # Builds the jax artifact through the reusable build workflow and uploads it to GCS.
  build-jax-artifact:
    uses: ./.github/workflows/build_artifacts.yml
    name: "Build jax artifact"
    with:
      # Note that since jax is a pure python package, the runner OS and Python values do not
      # matter. In addition, cloning main XLA also has no effect.
      runner: "linux-x86-n2-16"
      artifact: "jax"
      upload_artifacts_to_gcs: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Builds jaxlib for every runner/Python combination in the matrix and uploads it to GCS.
  build-jaxlib-artifact:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the CPU tests job
        runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"]
        artifact: ["jaxlib"]
        python: ["3.10"]
    # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
    # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
    # values to the name and creates a separate entry for each matrix combination.
    name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Builds the CUDA plugin and PJRT artifacts and uploads them to the same GCS location.
  build-cuda-artifacts:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the CUDA tests job below
        runner: ["linux-x86-n2-16"]
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
        python: ["3.10"]
    name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Runs the CPU pytest workflow against the artifacts built above.
  run-pytest-cpu:
    # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
    # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
    # still want to run the tests for other platforms.
    if: ${{ !cancelled() }}
    needs: [build-jax-artifact, build-jaxlib-artifact]
    uses: ./.github/workflows/pytest_cpu.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the
        # build-jaxlib-artifact job above
        runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
        python: ["3.10"]
        enable-x64: [1, 0]
    name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      # The jax and jaxlib build jobs are configured with the same GCS upload URI
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the CUDA pytest workflow across the GPU runner / CUDA version matrix.
  run-pytest-cuda:
    # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
    # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
    # still want to run the tests for other platforms.
    if: ${{ !cancelled() }}
    needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/pytest_cuda.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the artifact build jobs above
        # See exclusions for what is fully tested
        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
        python: ["3.10"]
        cuda: ["12.1", "12.3", "12.8"]
        enable-x64: [1, 0]
        # Exclusion values use the same scalar types as the matrix values above
        # (strings for cuda, integers for enable-x64) so that they match reliably.
        exclude:
          # L4 does not run on cuda 12.8 but tests other configs
          - runner: "linux-x86-g2-48-l4-4gpu"
            cuda: "12.8"
          # H100 runs only a single config, CUDA 12.3 Enable x64 1
          - runner: "linux-x86-a3-8g-h100-8gpu"
            cuda: "12.8"
          - runner: "linux-x86-a3-8g-h100-8gpu"
            cuda: "12.1"
          - runner: "linux-x86-a3-8g-h100-8gpu"
            enable-x64: 0
          # B200 runs only a single config, CUDA 12.8 Enable x64 1
          - runner: "linux-x86-a4-224-b200-1gpu"
            enable-x64: 0
          - runner: "linux-x86-a4-224-b200-1gpu"
            cuda: "12.1"
          - runner: "linux-x86-a4-224-b200-1gpu"
            cuda: "12.3"
    name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      cuda: ${{ matrix.cuda }}
      enable-x64: ${{ matrix.enable-x64 }}
      # GCS upload URI is the same for both artifact build jobs
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the CUDA tests through the non-RBE Bazel workflow on the L4 runner.
  run-bazel-test-cuda:
    # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
    # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
    # still want to run the tests for other platforms.
    if: ${{ !cancelled() }}
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/bazel_cuda_non_rbe.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the build artifacts job above
        runner: ["linux-x86-g2-48-l4-4gpu"]
        python: ["3.10"]
        enable-x64: [1, 0]
    name: "Bazel CUDA Non-RBE (JAX artifacts version = ${{ format('{0}', 'head') }})"
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      # GCS upload URI is the same for both artifact build jobs
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the TPU pytest workflow once per TPU spec listed in the matrix.
  run-pytest-tpu:
    # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
    # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
    # still want to run the tests for other platforms.
    if: ${{ !cancelled() }}
    needs: [build-jax-artifact, build-jaxlib-artifact]
    uses: ./.github/workflows/pytest_tpu.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        python: ["3.10"]
        # Each entry bundles the TPU type, its core count and the runner that hosts it.
        tpu-specs:
          # - {type: "v3-8", cores: "4"}  # Enable when we have the v3 type available
          - {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}
          - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
    name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
    with:
      runner: ${{ matrix.tpu-specs.runner }}
      cores: ${{ matrix.tpu-specs.cores }}
      tpu-type: ${{ matrix.tpu-specs.type }}
      python: ${{ matrix.python }}
      run-full-tpu-test-suite: "1"
      libtpu-version-type: "nightly"
      # The jax and jaxlib build jobs are configured with the same GCS upload URI
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}