mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add TPU test jobs to the new CI continuous and nightly/release test workflows
Also, modify the TPU presubmit workflow to reuse the `build_artifacts.yml` and `pytest_tpu.yml` PiperOrigin-RevId: 735832964
This commit is contained in:
parent
c2c68c018f
commit
7ac6355262
85
.github/workflows/cloud-tpu-ci-presubmit.yml
vendored
85
.github/workflows/cloud-tpu-ci-presubmit.yml
vendored
@ -3,6 +3,7 @@
|
||||
# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
|
||||
# tested to get to a stable state before we enable it as a blocking presubmit.
|
||||
name: CI - Cloud TPU (presubmit)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
@ -33,64 +34,32 @@ concurrency:
|
||||
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
|
||||
|
||||
jobs:
|
||||
cloud-tpu-test:
|
||||
build-jax-artifacts:
|
||||
if: github.event.repository.fork == false
|
||||
# Begin Presubmit Naming Check - name modification requires internal check to be updated
|
||||
uses: ./.github/workflows/build_artifacts.yml
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
tpu: [
|
||||
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
|
||||
]
|
||||
python-version: ["3.10"]
|
||||
name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
|
||||
# End Presubmit Naming Check github-tpu-presubmits
|
||||
env:
|
||||
JAXCI_PYTHON: python${{ matrix.python-version }}
|
||||
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
artifact: ["jax", "jaxlib"]
|
||||
with:
|
||||
runner: "linux-x86-n2-16"
|
||||
artifact: ${{ matrix.artifact }}
|
||||
python: "3.10"
|
||||
clone_main_xla: 1
|
||||
upload_artifacts_to_gcs: true
|
||||
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
|
||||
|
||||
runs-on: ${{ matrix.tpu.runner }}
|
||||
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
|
||||
|
||||
timeout-minutes: 60
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -ex {0}
|
||||
steps:
|
||||
# https://opensource.google/documentation/reference/github/services#actions
|
||||
# mandates using a specific commit for non-Google actions. We use
|
||||
# https://github.com/sethvargo/ratchet to pin specific versions.
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
# Checkout XLA at head, if we're building jaxlib at head.
|
||||
- name: Checkout XLA at head
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: openxla/xla
|
||||
path: xla
|
||||
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
|
||||
- name: Mark GitHub workspace as safe
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
- name: Install JAX test requirements
|
||||
run: |
|
||||
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
|
||||
- name: Build jaxlib at head with latest XLA
|
||||
run: |
|
||||
# Build and install jaxlib at head
|
||||
$JAXCI_PYTHON build/build.py build --wheels=jaxlib \
|
||||
--python_version=${{ matrix.python-version }} \
|
||||
--bazel_options=--config=rbe_linux_x86_64 \
|
||||
--local_xla_path="$(pwd)/xla" \
|
||||
--verbose
|
||||
|
||||
# Install libtpu
|
||||
$JAXCI_PYTHON -m uv pip install --pre libtpu \
|
||||
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
# Halt for testing
|
||||
- name: Wait For Connection
|
||||
uses: google-ml-infra/actions/ci_connection@main
|
||||
with:
|
||||
halt-dispatch-input: ${{ inputs.halt-for-connection }}
|
||||
- name: Install jaxlib wheel and run tests
|
||||
run: ./ci/run_pytest_tpu.sh
|
||||
run-pytest-tpu:
|
||||
if: github.event.repository.fork == false
|
||||
needs: [build-jax-artifacts]
|
||||
uses: ./.github/workflows/pytest_tpu.yml
|
||||
# Begin Presubmit Naming Check - name modification requires internal check to be updated
|
||||
name: "TPU test (jaxlib=head, v5e-8)"
|
||||
with:
|
||||
runner: "linux-x86-ct5lp-224-8tpu"
|
||||
cores: "8"
|
||||
tpu-type: "v5e-8"
|
||||
python: "3.10"
|
||||
libtpu-version-type: "nightly"
|
||||
gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }}
|
||||
# End Presubmit Naming Check github-tpu-presubmits
|
149
.github/workflows/pytest_tpu.yml
vendored
Normal file
149
.github/workflows/pytest_tpu.yml
vendored
Normal file
@ -0,0 +1,149 @@
|
||||
# CI - Pytest TPU
|
||||
#
|
||||
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
|
||||
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests.
|
||||
#
|
||||
# It consists of the following job:
|
||||
# run-tests:
|
||||
# - Downloads the jaxlib wheel from a GCS bucket.
|
||||
# - Sets up the libtpu wheels.
|
||||
# - Executes the `run_pytest_cpu.sh` script, which performs the following actions:
|
||||
# - Installs the downloaded jaxlib wheel.
|
||||
# - Runs the TPU tests with Pytest.
|
||||
name: CI - Pytest TPU
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
# Note that the values for runners, cores, and tpu-type are linked to each other.
|
||||
# For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
|
||||
# following mapping:
|
||||
# {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
|
||||
# {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
|
||||
runner:
|
||||
description: "Which runner should the workflow run on?"
|
||||
type: string
|
||||
required: true
|
||||
default: "linux-x86-ct5lp-224-8tpu"
|
||||
cores:
|
||||
description: "How many TPU cores should the test use?"
|
||||
type: string
|
||||
required: true
|
||||
default: "8"
|
||||
tpu-type:
|
||||
description: "Which TPU type is used for testing?"
|
||||
type: string
|
||||
required: true
|
||||
default: "v5e-8"
|
||||
python:
|
||||
description: "Which Python version should be used for testing?"
|
||||
type: string
|
||||
required: true
|
||||
default: "3.12"
|
||||
run-full-tpu-test-suite:
|
||||
description: "Should the full TPU test suite be run?"
|
||||
type: string
|
||||
required: false
|
||||
default: "0"
|
||||
libtpu-version-type:
|
||||
description: "Which libtpu version should be used for testing?"
|
||||
type: string
|
||||
required: false
|
||||
# Choices are:
|
||||
# - "nightly": Use the nightly libtpu wheel.
|
||||
# - "pypi_latest": Use the latest libtpu wheel from PyPI.
|
||||
# - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
|
||||
default: "nightly"
|
||||
gcs_download_uri:
|
||||
description: "GCS location prefix from where the artifacts should be downloaded"
|
||||
required: true
|
||||
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
|
||||
type: string
|
||||
halt-for-connection:
|
||||
description: 'Should this workflow run wait for a remote connection?'
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
run-tests:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
runs-on: ${{ inputs.runner }}
|
||||
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
|
||||
# Begin Presubmit Naming Check - name modification requires internal check to be updated
|
||||
name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
|
||||
# End Presubmit Naming Check github-tpu-presubmits
|
||||
|
||||
env:
|
||||
LIBTPU_OLDEST_VERSION_DATE: 20241205
|
||||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
|
||||
JAXCI_PYTHON: "python${{ inputs.python }}"
|
||||
JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
|
||||
JAXCI_TPU_CORES: "${{ inputs.cores }}"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set env vars for use in artifact download URL
|
||||
run: |
|
||||
os=$(uname -s | awk '{print tolower($0)}')
|
||||
arch=$(uname -m)
|
||||
|
||||
# Get the major and minor version of Python.
|
||||
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
|
||||
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
|
||||
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
|
||||
|
||||
echo "OS=${os}" >> $GITHUB_ENV
|
||||
echo "ARCH=${arch}" >> $GITHUB_ENV
|
||||
# Python wheels follow a naming convention: standard wheels use the pattern
|
||||
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
|
||||
# `*-cp<py_version>-cp<py_version>t-*`.
|
||||
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
|
||||
- name: Download JAX wheels from GCS
|
||||
id: download-wheel-artifacts
|
||||
# Set continue-on-error to true to prevent actions from failing the workflow if this step
|
||||
# fails. Instead, we verify the outcome in the step below so that we can print a more
|
||||
# informative error message.
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mkdir -p $(pwd)/dist
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
|
||||
- name: Skip the test run if the wheel artifacts were not downloaded successfully
|
||||
if: steps.download-wheel-artifacts.outcome == 'failure'
|
||||
run: |
|
||||
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
|
||||
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
|
||||
echo "Skipping the test run."
|
||||
exit 1
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
|
||||
- name: Set up libtpu wheels
|
||||
run: |
|
||||
if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then
|
||||
echo "Using nightly libtpu"
|
||||
$JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then
|
||||
echo "Using latest libtpu from PyPI"
|
||||
# Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
|
||||
# script will install the latest libtpu wheel from PyPI.
|
||||
echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV
|
||||
elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then
|
||||
echo "Using oldest supported libtpu"
|
||||
$JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
|
||||
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
else
|
||||
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
|
||||
exit 1
|
||||
fi
|
||||
# Halt for testing
|
||||
- name: Wait For Connection
|
||||
uses: google-ml-infra/actions/ci_connection@main
|
||||
with:
|
||||
halt-dispatch-input: ${{ inputs.halt-for-connection }}
|
||||
- name: Run Pytest TPU tests
|
||||
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
|
||||
run: ./ci/run_pytest_tpu.sh
|
26
.github/workflows/wheel_tests_continuous.yml
vendored
26
.github/workflows/wheel_tests_continuous.yml
vendored
@ -142,4 +142,30 @@ jobs:
|
||||
python: ${{ matrix.python }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
# GCS upload URI is the same for both artifact build jobs
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
||||
|
||||
run-pytest-tpu:
|
||||
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
|
||||
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
|
||||
# still want to run the tests for other platforms.
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [build-jax-artifact, build-jaxlib-artifact]
|
||||
uses: ./.github/workflows/pytest_tpu.yml
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
python: ["3.10",]
|
||||
tpu-specs: [
|
||||
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
|
||||
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
|
||||
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
|
||||
]
|
||||
name: "TPU tests (jax=head, jaxlib=head)"
|
||||
with:
|
||||
runner: ${{ matrix.tpu-specs.runner }}
|
||||
cores: ${{ matrix.tpu-specs.cores }}
|
||||
tpu-type: ${{ matrix.tpu-specs.type }}
|
||||
python: ${{ matrix.python }}
|
||||
run-full-tpu-test-suite: "1"
|
||||
libtpu-version-type: "nightly"
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
@ -58,4 +58,29 @@ jobs:
|
||||
python: ${{ matrix.python }}
|
||||
cuda: ${{ matrix.cuda }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
||||
|
||||
run-pytest-tpu:
|
||||
uses: ./.github/workflows/pytest_tpu.yml
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
matrix:
|
||||
python: ["3.10","3.11", "3.12", "3.13"]
|
||||
tpu-specs: [
|
||||
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
|
||||
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
|
||||
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
|
||||
]
|
||||
libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"]
|
||||
exclude:
|
||||
- libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }}
|
||||
- libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }}
|
||||
name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
|
||||
with:
|
||||
runner: ${{ matrix.tpu-specs.runner }}
|
||||
cores: ${{ matrix.tpu-specs.cores }}
|
||||
tpu-type: ${{ matrix.tpu-specs.type }}
|
||||
python: ${{ matrix.python }}
|
||||
run-full-tpu-test-suite: "1"
|
||||
libtpu-version-type: ${{ matrix.libtpu-version-type }}
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
@ -74,4 +74,14 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
|
||||
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
|
||||
# on the system. By default, it is set to match the version of the hermetic
|
||||
# Python used by Bazel for building the wheels.
|
||||
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
|
||||
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
|
||||
|
||||
# When set to 1, the full TPU test suite is run. Otherwise, a subset of tests
|
||||
# is run.
|
||||
export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0}
|
||||
|
||||
# We use this environment variable to control which additional wheels to install
|
||||
# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest
|
||||
# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the
|
||||
# list of valid values and their behavior.
|
||||
export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""}
|
@ -52,23 +52,37 @@ export JAX_SKIP_SLOW_TESTS=true
|
||||
|
||||
echo "Running TPU tests..."
|
||||
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
|
||||
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
|
||||
--maxfail=20 -m "not multiaccelerator" \
|
||||
tests/pallas/ops_test.py \
|
||||
tests/pallas/export_back_compat_pallas_test.py \
|
||||
tests/pallas/export_pallas_test.py \
|
||||
tests/pallas/tpu_ops_test.py \
|
||||
tests/pallas/tpu_pallas_test.py \
|
||||
tests/pallas/tpu_pallas_random_test.py \
|
||||
tests/pallas/tpu_pallas_async_test.py \
|
||||
tests/pallas/tpu_pallas_state_test.py
|
||||
if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
|
||||
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
|
||||
--maxfail=20 -m "not multiaccelerator" tests examples
|
||||
|
||||
# Run Pallas printing tests, which need to run with I/O capturing disabled.
|
||||
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
|
||||
# Run Pallas printing tests, which need to run with I/O capturing disabled.
|
||||
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \
|
||||
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
|
||||
|
||||
# Run multi-accelerator across all chips
|
||||
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
|
||||
tests/pjit_test.py \
|
||||
tests/pallas/tpu_pallas_distributed_test.py
|
||||
# Run multi-accelerator across all chips
|
||||
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
|
||||
else
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
|
||||
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
|
||||
--maxfail=20 -m "not multiaccelerator" \
|
||||
tests/pallas/ops_test.py \
|
||||
tests/pallas/export_back_compat_pallas_test.py \
|
||||
tests/pallas/export_pallas_test.py \
|
||||
tests/pallas/tpu_ops_test.py \
|
||||
tests/pallas/tpu_pallas_test.py \
|
||||
tests/pallas/tpu_pallas_random_test.py \
|
||||
tests/pallas/tpu_pallas_async_test.py \
|
||||
tests/pallas/tpu_pallas_state_test.py
|
||||
|
||||
# Run Pallas printing tests, which need to run with I/O capturing disabled.
|
||||
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
|
||||
|
||||
# Run multi-accelerator across all chips
|
||||
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
|
||||
tests/pjit_test.py \
|
||||
tests/pallas/tpu_pallas_distributed_test.py
|
||||
fi
|
@ -17,8 +17,19 @@
|
||||
# Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python
|
||||
# binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to
|
||||
# avoid using the Windows version of `find` on Msys.
|
||||
|
||||
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
|
||||
|
||||
for i in "${!WHEELS[@]}"; do
|
||||
if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then
|
||||
if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then
|
||||
# Append [tpu] to the jax wheel name to download the latest libtpu wheel
|
||||
# from PyPI.
|
||||
WHEELS[$i]="${WHEELS[$i]}[tpu]"
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ -z "${WHEELS[@]}" ]]; then
|
||||
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
|
||||
exit 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user