mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
135 lines
6.8 KiB
YAML
135 lines
6.8 KiB
YAML
# Cloud TPU CI
#
# This job currently runs twice per day. We use self-hosted TPU runners, so we'd
# have to add more runners to run on every commit.
#
# This job's build matrix runs over several TPU architectures using both the
# latest released jaxlib on PyPI ("pypi_latest") and the latest nightly
# jaxlib ("nightly"). It also installs a matching libtpu, either the one pinned
# to the release for "pypi_latest", or the latest nightly for "nightly". It
# always locally installs jax from GitHub head (already checked out by the
# GitHub Actions environment).

name: CI - Cloud TPU (nightly)

on:
  schedule:
    # GitHub evaluates cron expressions in UTC: this fires at 02:00 and 14:00
    # UTC, i.e. 7pm and 7am PDT (6pm and 6am PST).
    - cron: "0 2,14 * * *"
  workflow_dispatch: # allows triggering the workflow run manually

# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read

jobs:
  # Installs jax from the checked-out tree together with one of several
  # jaxlib/libtpu combinations, then runs the test suite on a TPU runner.
  cloud-tpu-test:
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        # Each value selects one branch of the "Install JAX" step below.
        jaxlib-version:
          - "head"
          - "pypi_latest"
          - "nightly"
          - "nightly+oldest_supported_libtpu"
        # `cores` is passed to pytest-xdist as the worker count; `runner` is
        # the self-hosted runner label the job is scheduled on.
        tpu:
          # Enable when we have the v3 type available:
          # - {type: "v3-8", cores: "4"}
          - {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}
          - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
        # Quoted so that 3.10 is not parsed as the float 3.1.
        python-version: ["3.10"]
    name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
    env:
      # Date suffix of the libtpu-nightly build pinned by the
      # "nightly+oldest_supported_libtpu" configuration. Quoted so it stays a
      # string rather than being typed as an integer.
      LIBTPU_OLDEST_VERSION_DATE: "20241118"
      PYTHON: "python${{ matrix.python-version }}"
    runs-on: ${{ matrix.tpu.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    timeout-minutes: 180
    defaults:
      run:
        # -e: abort the step on the first failing command; -x: echo commands.
        shell: bash -ex {0}
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      # Checkout XLA at head, if we're building jaxlib at head.
      - name: Checkout XLA at head
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        if: ${{ matrix.jaxlib-version == 'head' }}
        with:
          repository: openxla/xla
          path: xla

      # We need to mark the GitHub workspace as safe as otherwise git commands will fail.
      - name: Mark GitHub workspace as safe
        run: |
          git config --global --add safe.directory "$GITHUB_WORKSPACE"

      - name: Install JAX test requirements
        run: |
          $PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt

      - name: Install JAX
        run: |
          # Start from a clean slate so the versions printed at the end of this
          # step are the ones installed by the branch selected below.
          $PYTHON -m uv pip uninstall -y jax jaxlib libtpu
          if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
            # Build and install jaxlib at head
            $PYTHON build/build.py build --wheels=jaxlib \
                    --bazel_options=--config=rbe_linux_x86_64 \
                    --local_xla_path="$(pwd)/xla" \
                    --verbose

            # Install jaxlib, "jax" at head, and libtpu
            $PYTHON -m uv pip install dist/*.whl \
              -U -e . \
              --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
            $PYTHON -m uv pip install .[tpu] \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

          elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
            # `requests` is a plain extra requirement of this single install
            # command. Do not prefix it with `install`: inside an already
            # running `pip install` that word is resolved as the unrelated PyPI
            # project named "install".
            $PYTHON -m uv pip install \
              --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
              --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
              requests

          elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
            # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
            # As above, `requests` is listed as a bare requirement (no stray
            # `install` token).
            $PYTHON -m uv pip install \
              --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
              --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
              requests
          else
            echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
            exit 1
          fi

          # Log the versions that were actually installed.
          $PYTHON -c 'import sys; print("python version:", sys.version)'
          $PYTHON -c 'import jax; print("jax version:", jax.__version__)'
          $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
          strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
          $PYTHON -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'

      - name: Run tests
        env:
          JAX_PLATFORMS: tpu,cpu
          PY_COLORS: "1"
        run: |
          # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic TPU does not
          # guarantee anything about forward compatibility (unless jax.export is used) and the 12
          # week compatibility window accumulates way too many failures.
          IGNORE_FLAGS=
          if [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
            IGNORE_FLAGS="--ignore=tests/pallas"
          fi
          # Run single-accelerator tests in parallel
          JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
            --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
            --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
          # Run Pallas printing tests, which need to run with I/O capturing disabled.
          # NOTE(review): this invocation does not use $IGNORE_FLAGS, so these
          # Pallas tests still run in the oldest-libtpu configuration; confirm
          # whether that is intended.
          TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
            tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
          # Run multi-accelerator across all chips
          $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests

      - name: Send chat on failure
        # Don't notify when testing the workflow from a branch.
        if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
        run: |
          curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
          --header 'Content-Type: application/json' \
          --data-raw "{
          'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
          }"