# CI - Pytest TPU
#
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests.
#
# It consists of the following job:
# run-tests:
#    - Downloads the jax and jaxlib wheels from a GCS bucket.
#    - Installs the Python test dependencies.
#    - Sets up the libtpu wheels.
#    - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
#      - Installs the downloaded jaxlib wheel.
#      - Runs the TPU tests with Pytest.
name: CI - Pytest TPU

on:
  workflow_call:
    inputs:
      # Note that the values for runners, cores, and tpu-type are linked to each other.
      # For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
      # following mapping:
      # {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
      # {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: true
        default: "linux-x86-ct5lp-224-8tpu"
      cores:
        description: "How many TPU cores should the test use?"
        type: string
        required: true
        default: "8"
      tpu-type:
        description: "Which TPU type is used for testing?"
        type: string
        required: true
        default: "v5e-8"
      python:
        description: "Which Python version should be used for testing?"
        type: string
        required: true
        default: "3.12"
      run-full-tpu-test-suite:
        description: "Should the full TPU test suite be run?"
        type: string
        required: false
        default: "0"
      libtpu-version-type:
        description: "Which libtpu version should be used for testing?"
        type: string
        required: false
        # Choices are:
        # - "nightly": Use the nightly libtpu wheel.
        # - "pypi_latest": Use the latest libtpu wheel from PyPI.
        # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
        default: "nightly"
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false

jobs:
  run-tests:
    defaults:
      run:
        shell: bash
    runs-on: ${{ inputs.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    # Begin Presubmit Naming Check - name modification requires internal check to be updated
    name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
    # End Presubmit Naming Check github-tpu-presubmits

    env:
      # Date stamp of the oldest supported libtpu nightly; used to build the
      # `libtpu-nightly==0.1.dev<date>` pin below. Quoted so it stays a string.
      LIBTPU_OLDEST_VERSION_DATE: "20241205"
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
      JAXCI_TPU_CORES: "${{ inputs.cores }}"

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2

      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)

          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
          python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          # Python wheels follow a naming convention: standard wheels use a pattern such as
          # `*-cp313-cp313-*`, while free-threaded wheels use `*-cp313-cp313t-*`. The `%t`
          # expansion strips the trailing "t" from the first tag only.
          echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> "$GITHUB_ENV"

      - name: Download JAX wheels from GCS
        id: download-wheel-artifacts
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the step below so that we can print a more
        # informative error message.
        continue-on-error: true
        env:
          # Pass the input through the environment instead of expanding it inside the script,
          # so its contents are never interpreted as shell code.
          GCS_DOWNLOAD_URI: "${{ inputs.gcs_download_uri }}"
        run: |
          mkdir -p "$(pwd)/dist"
          # The wildcards are quoted so that gsutil, not the shell, expands them against GCS.
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*py3*none*any.whl" "$(pwd)/dist/"
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/"

      # Despite the step name, this step fails the job (exit 1) when the download step failed.
      - name: Skip the test run if the wheel artifacts were not downloaded successfully
        if: steps.download-wheel-artifacts.outcome == 'failure'
        run: |
          echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
          echo "built successfully by the artifact build jobs and are available in the GCS bucket."
          echo "Skipping the test run."
          exit 1

      - name: Install Python dependencies
        run: |
          $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt

      - name: Set up libtpu wheels
        env:
          # Pass the input through the environment instead of expanding it inside the script,
          # so its contents are never interpreted as shell code.
          LIBTPU_VERSION_TYPE: "${{ inputs.libtpu-version-type }}"
        run: |
          if [[ "${LIBTPU_VERSION_TYPE}" == "nightly" ]]; then
            echo "Using nightly libtpu"
            $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [[ "${LIBTPU_VERSION_TYPE}" == "pypi_latest" ]]; then
            echo "Using latest libtpu from PyPI"
            # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
            # script will install the latest libtpu wheel from PyPI.
            echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> "$GITHUB_ENV"
          elif [[ "${LIBTPU_VERSION_TYPE}" == "oldest_supported_libtpu" ]]; then
            echo "Using oldest supported libtpu"
            $JAXCI_PYTHON -m uv pip install --pre "libtpu-nightly==0.1.dev${LIBTPU_OLDEST_VERSION_DATE}" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

            echo "libtpu_version_type=oldest_supported_libtpu" >> "$GITHUB_ENV"
          else
            echo "Unknown libtpu version type: ${LIBTPU_VERSION_TYPE}"
            exit 1
          fi

      # Halt for testing
      # NOTE(review): this action is referenced by the mutable `main` branch, unlike the
      # SHA-pinned checkout above; consider pinning it to a commit SHA.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}

      - name: Run Pytest TPU tests
        # Pull-request runs get a 30 minute budget; all other triggering events get 180 minutes.
        timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
        run: ./ci/run_pytest_tpu.sh