mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add an experimental Cloud TPU presubmit job
This adds an experimental non-blocking presubmit job that will run a subset of TPU tests, focusing on frequently failing tests. The goal is to achieve comprehensive coverage while keeping the runtime around 10 minutes. PiperOrigin-RevId: 706064568
This commit is contained in:
parent
f4e5f14a7b
commit
d05ab5bb0d
93
.github/workflows/cloud-tpu-ci-presubmit.yml
vendored
Normal file
93
.github/workflows/cloud-tpu-ci-presubmit.yml
vendored
Normal file
@ -0,0 +1,93 @@
|
||||
# Cloud TPU CI (presubmit)
#
# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
# tested to get to a stable state before we enable it as a blocking presubmit.
name: CI - Cloud TPU (presubmit)
|
||||
on:
  # Manual trigger. The single input lets the caller ask the run to wait for a
  # remote connection; it is forwarded to the "Wait For Connection" step.
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        # 'yes' / 'no' stay quoted so they are strings, not YAML 1.1 booleans.
        default: 'no'
        options:
          - 'yes'
          - 'no'
  # Presubmit trigger: pull requests that target the main branch.
  pull_request:
    branches:
      - main
|
||||
|
||||
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read
|
||||
|
||||
# Runs are grouped by workflow name plus the PR head ref (falling back to the
# triggering ref). A newer run in the same group cancels the in-progress one.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
|
||||
|
||||
jobs:
  # Builds jaxlib at head against XLA at head, then runs the TPU test script
  # (ci/run_pytest_tpu.sh) on a Cloud TPU runner.
  cloud-tpu-test:
    # Only run when the repository that received the event is not a fork.
    if: github.event.repository.fork == false
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # One entry per TPU configuration; `runner` is the runs-on label and
        # `cores` is exported as JAXCI_TPU_CORES below.
        tpu:
          - type: "v5e-8"
            cores: "8"
            runner: "linux-x86-ct5lp-224-8tpu"
        # Quoted so YAML keeps the string "3.10" instead of the float 3.1.
        python-version: ["3.10"]

    name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"

    env:
      # Interpreter name used by the steps below and by the CI scripts, e.g. python3.10.
      JAXCI_PYTHON: python${{ matrix.python-version }}
      # Read by ci/run_pytest_tpu.sh as the pytest-xdist worker count (-n).
      JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}

    runs-on: ${{ matrix.tpu.runner }}
    # NOTE(review): `latest` is a moving tag, so the build image can change
    # between runs — confirm this is intended for a presubmit.
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"

    timeout-minutes: 60

    defaults:
      run:
        # -e: stop at the first failing command; -x: echo each command to the log.
        shell: bash -ex {0}

    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      # Checkout XLA at head, if we're building jaxlib at head.
      - name: Checkout XLA at head
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          repository: openxla/xla
          path: xla
      # We need to mark the GitHub workspace as safe as otherwise git commands will fail.
      - name: Mark GitHub workspace as safe
        run: |
          git config --global --add safe.directory "$GITHUB_WORKSPACE"
      - name: Install JAX test requirements
        run: |
          $JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
          $JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
      - name: Build jaxlib at head with latest XLA
        run: |
          # Build the jaxlib wheel at head against the XLA checkout in ./xla.
          # The wheel is installed later, by ci/run_pytest_tpu.sh.
          $JAXCI_PYTHON build/build.py build --wheels=jaxlib \
            --python_version=${{ matrix.python-version }} \
            --bazel_options=--config=rbe_linux_x86_64 \
            --local_xla_path="$(pwd)/xla" \
            --verbose

          # Install libtpu
          $JAXCI_PYTHON -m pip install --pre libtpu \
            -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      # Halt for testing
      # NOTE(review): this action tracks the moving `main` ref rather than a
      # pinned commit; presumably exempt from the pinning rule above because it
      # is Google-owned — confirm.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Install jaxlib wheel and run tests
        run: ./ci/run_pytest_tpu.sh
|
25
ci/run_pytest_tpu.sh
Normal file → Executable file
25
ci/run_pytest_tpu.sh
Normal file → Executable file
@ -33,29 +33,28 @@ source ./ci/utilities/install_wheels_locally.sh
|
||||
# Set up the build environment.
|
||||
source "ci/utilities/setup_build_environment.sh"
|
||||
|
||||
export PY_COLORS=1
|
||||
export JAX_SKIP_SLOW_TESTS=true
|
||||
|
||||
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
|
||||
|
||||
"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
|
||||
"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'
|
||||
"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
|
||||
strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on'
|
||||
strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
|
||||
"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
|
||||
|
||||
echo "Running TPU tests..."
|
||||
# Set up all common test environment variables
|
||||
export PY_COLORS=1
|
||||
export JAX_PLATFORMS=tpu,cpu
|
||||
# Run single-accelerator tests in parallel
|
||||
export JAX_ENABLE_TPU_XDIST=true
|
||||
export JAX_SKIP_SLOW_TESTS=true
|
||||
# End of common test environment variable setup
|
||||
|
||||
"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
|
||||
echo "Running TPU tests..."
|
||||
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
|
||||
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
|
||||
--maxfail=20 -m "not multiaccelerator" tests examples
|
||||
--maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py
|
||||
|
||||
# Run Pallas printing tests, which need to run with I/O capturing disabled.
|
||||
export TPU_STDERR_LOG_LEVEL=0
|
||||
"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
|
||||
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
|
||||
|
||||
# Run multi-accelerator across all chips
|
||||
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
|
||||
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py
|
Loading…
x
Reference in New Issue
Block a user