Download and use jax wheels from GCS bucket for nightly/release test workflows

Unlike continuous workflows, when testing nightly/release artifacts, we want to download and install the `jax` wheels found in the GCS bucket instead of installing it from HEAD.

It looks like `env` setting in the calling workflow isn't passed over to the called workflows so we define a new workflow input, `install-jax-current-commit`, to control the `jax` install behavior.

PiperOrigin-RevId: 726086522
This commit is contained in:
Nitin Srinivasan 2025-02-12 09:30:41 -08:00 committed by jax authors
parent 837418c652
commit 93831bdde7
4 changed files with 38 additions and 8 deletions

View File

@ -29,6 +29,11 @@ on:
type: string
required: true
default: "0"
install-jax-current-commit:
description: "Should the 'jax' package be installed from the current commit?"
type: string
required: true
default: "1"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
@ -57,6 +62,7 @@ jobs:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -85,9 +91,14 @@ jobs:
# informative error message.
continue-on-error: true
if: ${{ !contains(inputs.runner, 'windows-x86') }}
run: >-
run: |
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/
fi
- name: Download jaxlib wheel from GCS (Windows runs)
id: download-wheel-artifacts-w
# Set continue-on-error to true to prevent actions from failing the workflow if this step
@ -96,9 +107,14 @@ jobs:
continue-on-error: true
if: ${{ contains(inputs.runner, 'windows-x86') }}
shell: cmd
run: >-
run: |
mkdir dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if not "${{ inputs.install-jax-current-commit }}"=="1" (
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl dist/
)
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
run: |

View File

@ -34,6 +34,11 @@ on:
type: string
required: true
default: "0"
install-jax-current-commit:
description: "Should the 'jax' package be installed from the current commit?"
type: string
required: true
default: "1"
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
required: true
@ -61,6 +66,7 @@ jobs:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -82,11 +88,16 @@ jobs:
# fails. Instead, we verify the outcome in the next step so that we can print a more
# informative error message.
continue-on-error: true
run: >-
run: |
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/
fi
- name: Skip the test run if the wheel artifacts were not downloaded successfully
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |

View File

@ -79,6 +79,7 @@ jobs:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
install-jax-current-commit: 1
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-pytest-cuda:
@ -109,6 +110,7 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
install-jax-current-commit: 1
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

View File

@ -22,11 +22,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
env:
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "0"
jobs:
run-pytest-cpu:
uses: ./.github/workflows/pytest_cpu.yml
@ -42,6 +37,9 @@ jobs:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
install-jax-current-commit: 0
gcs_download_uri: ${{inputs.gcs_download_uri}}
run-pytest-cuda:
@ -60,4 +58,7 @@ jobs:
python: ${{ matrix.python }}
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
# GCS bucket.
install-jax-current-commit: 0
gcs_download_uri: ${{inputs.gcs_download_uri}}