mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
192 lines
8.9 KiB
YAML
192 lines
8.9 KiB
YAML
# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU, CUDA, and TPU tests.
#
# It orchestrates the following:
# 1. build-jax-artifact: Calls the `build_artifacts.yml` workflow to build the jax artifact and
#                        uploads it to a GCS bucket.
# 2. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
#                           uploads it to a GCS bucket.
# 3. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel
#                    that was built in the previous step and runs CPU tests.
# 4. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
#                          uploads them to a GCS bucket.
# 5. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
#                     artifacts that were built in the previous steps and runs the CUDA tests.
# 6. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib
#                         and CUDA artifacts that were built in the previous steps and runs the
#                         CUDA tests using Bazel.
# 7. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow, passing it the GCS location of the
#                    artifacts that were built in the previous steps, to run the TPU tests.
|
|
|
|
name: 'CI - Wheel Tests (Continuous)'

on:
  # Allows triggering the workflow run manually.
  workflow_dispatch:
  # Run once every 2 hours, at minute 0.
  schedule:
    - cron: '0 */2 * * *'
|
|
|
|
# Runs of this workflow for the same ref share one concurrency group.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # A newer run cancels an in-progress one, except on release branches and on main.
  # `github.ref` is a fully-formed ref (e.g. `refs/heads/main`), so it has to be compared against
  # that form; comparing against the bare branch name `main` never matches.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
|
|
|
|
jobs:
|
|
# Builds the jax artifact through the reusable build workflow and uploads it to GCS.
build-jax-artifact:
  name: 'Build jax artifact'
  uses: ./.github/workflows/build_artifacts.yml
  with:
    # jax is a pure Python package, so the runner OS and Python values do not matter here.
    # Cloning main XLA has no effect on it either.
    artifact: 'jax'
    runner: 'linux-x86-n2-16'
    upload_artifacts_to_gcs: true
    gcs_upload_uri: "gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}"
|
|
|
|
# Builds jaxlib once per runner in the matrix and uploads each artifact to GCS.
build-jaxlib-artifact:
  # Note: For reasons unknown, GitHub Actions groups jobs with the same top-level name in the
  # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
  # values to the name and creates a separate entry for each matrix combination.
  name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Runner OS and Python values need to match the matrix strategy in the CPU tests job.
      runner:
        - 'linux-x86-n2-16'
        - 'linux-arm64-t2a-48'
        - 'windows-x86-n2-64'
      artifact:
        - 'jaxlib'
      python:
        - '3.10'
  with:
    artifact: ${{ matrix.artifact }}
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    clone_main_xla: 1
    upload_artifacts_to_gcs: true
    gcs_upload_uri: "gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}"
|
|
|
|
# Builds the CUDA plugin and CUDA PJRT artifacts and uploads them to GCS.
build-cuda-artifacts:
  # The expression in the name keeps the matrix jobs grouped under a single dashboard entry.
  name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Python values need to match the matrix strategy in the CUDA tests job below.
      runner:
        - 'linux-x86-n2-16'
      artifact:
        - 'jax-cuda-plugin'
        - 'jax-cuda-pjrt'
      python:
        - '3.10'
  with:
    artifact: ${{ matrix.artifact }}
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    clone_main_xla: 1
    upload_artifacts_to_gcs: true
    gcs_upload_uri: "gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}"
|
|
|
|
# Runs the CPU pytest workflow against the artifacts produced by the build jobs.
run-pytest-cpu:
  name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  needs: [build-jax-artifact, build-jaxlib-artifact]
  # Run test jobs even if the build job fails. This avoids losing test coverage when a single
  # unrelated build job fails, e.g. the Windows build fails but everything else succeeds; in
  # that case we still want to run the tests for the other platforms.
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/pytest_cpu.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Runner OS and Python values need to match the matrix strategy in the
      # build-jaxlib-artifact job above.
      runner:
        - 'linux-x86-n2-64'
        - 'linux-arm64-t2a-48'
        - 'windows-x86-n2-64'
      python:
        - '3.10'
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
|
|
|
# Runs the CUDA pytest workflow on the GPU runners against the artifacts from the build jobs.
run-pytest-cuda:
  # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
  # build job fails. E.g. Windows build job fails but everything else succeeds. In this case, we
  # still want to run the tests for other platforms.
  if: ${{ !cancelled() }}
  needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
  uses: ./.github/workflows/pytest_cuda.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Python values need to match the matrix strategy in the artifact build jobs above.
      # See the exclusions below for what is fully tested.
      runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
      python: ["3.10"]
      cuda: ["12.1", "12.3", "12.8"]
      enable-x64: [1, 0]
      # `enable-x64` is written as an integer in the exclusions, the same type as the values of
      # the `enable-x64` axis above, so that an exclusion cannot miss because of a type mismatch.
      exclude:
        # L4 does not run on CUDA 12.8 but tests the other configs.
        - runner: "linux-x86-g2-48-l4-4gpu"
          cuda: "12.8"
        # H100 runs only a single config: CUDA 12.3, enable-x64 1.
        - runner: "linux-x86-a3-8g-h100-8gpu"
          cuda: "12.8"
        - runner: "linux-x86-a3-8g-h100-8gpu"
          cuda: "12.1"
        - runner: "linux-x86-a3-8g-h100-8gpu"
          enable-x64: 0
        # B200 runs only a single config: CUDA 12.8, enable-x64 1.
        - runner: "linux-x86-a4-224-b200-1gpu"
          enable-x64: 0
        - runner: "linux-x86-a4-224-b200-1gpu"
          cuda: "12.1"
        - runner: "linux-x86-a4-224-b200-1gpu"
          cuda: "12.3"

  name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    cuda: ${{ matrix.cuda }}
    enable-x64: ${{ matrix.enable-x64 }}
    # GCS upload URI is the same for both artifact build jobs.
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
|
|
|
# Runs the Bazel (non-RBE) CUDA test workflow against the artifacts from the build jobs.
run-bazel-test-cuda:
  name: "Bazel CUDA Non-RBE (JAX artifacts version = ${{ format('{0}', 'head') }})"
  needs: [build-jaxlib-artifact, build-cuda-artifacts]
  # Run test jobs even if the build job fails. This avoids losing test coverage when a single
  # unrelated build job fails, e.g. the Windows build fails but everything else succeeds; in
  # that case we still want to run the tests for the other platforms.
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/bazel_cuda_non_rbe.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Python values need to match the matrix strategy in the build artifacts job above.
      runner:
        - 'linux-x86-g2-48-l4-4gpu'
      python:
        - '3.10'
      enable-x64:
        - 1
        - 0
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    # GCS upload URI is the same for both artifact build jobs.
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
|
|
|
# Runs the TPU pytest workflow on each TPU spec against the artifacts from the build jobs.
run-pytest-tpu:
  name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  needs: [build-jax-artifact, build-jaxlib-artifact]
  # Run test jobs even if the build job fails. This avoids losing test coverage when a single
  # unrelated build job fails, e.g. the Windows build fails but everything else succeeds; in
  # that case we still want to run the tests for the other platforms.
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/pytest_tpu.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      python:
        - '3.10'
      # Each entry bundles a TPU type with its core count and the runner that provides it.
      tpu-specs:
        # Enable when we have the v3 type available:
        # - type: 'v3-8'
        #   cores: '4'
        - type: 'v4-8'
          cores: '4'
          runner: 'linux-x86-ct4p-240-4tpu'
        - type: 'v5e-8'
          cores: '8'
          runner: 'linux-x86-ct5lp-224-8tpu'
  with:
    runner: ${{ matrix.tpu-specs.runner }}
    cores: ${{ matrix.tpu-specs.cores }}
    tpu-type: ${{ matrix.tpu-specs.type }}
    python: ${{ matrix.python }}
    run-full-tpu-test-suite: '1'
    libtpu-version-type: 'nightly'
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}