mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Download and use jax
wheels from GCS bucket for nightly/release test workflows
Unlike continuous workflows, when testing nightly/release artifacts, we want to download and install the `jax` wheels found in the GCS bucket instead of installing it from HEAD. It looks like `env` setting in the calling workflow isn't passed over to the called workflows so we define a new workflow input, `install-jax-current-commit`, to control the `jax` install behavior. PiperOrigin-RevId: 726086522
This commit is contained in:
parent
837418c652
commit
93831bdde7
20
.github/workflows/pytest_cpu.yml
vendored
20
.github/workflows/pytest_cpu.yml
vendored
@ -29,6 +29,11 @@ on:
|
||||
type: string
|
||||
required: true
|
||||
default: "0"
|
||||
install-jax-current-commit:
|
||||
description: "Should the 'jax' package be installed from the current commit?"
|
||||
type: string
|
||||
required: true
|
||||
default: "1"
|
||||
gcs_download_uri:
|
||||
description: "GCS location prefix from where the artifacts should be downloaded"
|
||||
required: true
|
||||
@ -57,6 +62,7 @@ jobs:
|
||||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
|
||||
JAXCI_PYTHON: "python${{ inputs.python }}"
|
||||
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
|
||||
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -85,9 +91,14 @@ jobs:
|
||||
# informative error message.
|
||||
continue-on-error: true
|
||||
if: ${{ !contains(inputs.runner, 'windows-x86') }}
|
||||
run: >-
|
||||
run: |
|
||||
mkdir -p $(pwd)/dist &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
|
||||
|
||||
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
|
||||
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/
|
||||
fi
|
||||
- name: Download jaxlib wheel from GCS (Windows runs)
|
||||
id: download-wheel-artifacts-w
|
||||
# Set continue-on-error to true to prevent actions from failing the workflow if this step
|
||||
@ -96,9 +107,14 @@ jobs:
|
||||
continue-on-error: true
|
||||
if: ${{ contains(inputs.runner, 'windows-x86') }}
|
||||
shell: cmd
|
||||
run: >-
|
||||
run: |
|
||||
mkdir dist &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
|
||||
|
||||
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
|
||||
if not "${{ inputs.install-jax-current-commit }}"=="1" (
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl dist/
|
||||
)
|
||||
- name: Skip the test run if the wheel artifacts were not downloaded successfully
|
||||
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
|
||||
run: |
|
||||
|
13
.github/workflows/pytest_cuda.yml
vendored
13
.github/workflows/pytest_cuda.yml
vendored
@ -34,6 +34,11 @@ on:
|
||||
type: string
|
||||
required: true
|
||||
default: "0"
|
||||
install-jax-current-commit:
|
||||
description: "Should the 'jax' package be installed from the current commit?"
|
||||
type: string
|
||||
required: true
|
||||
default: "1"
|
||||
gcs_download_uri:
|
||||
description: "GCS location prefix from where the artifacts should be downloaded"
|
||||
required: true
|
||||
@ -61,6 +66,7 @@ jobs:
|
||||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
|
||||
JAXCI_PYTHON: "python${{ inputs.python }}"
|
||||
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
|
||||
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -82,11 +88,16 @@ jobs:
|
||||
# fails. Instead, we verify the outcome in the next step so that we can print a more
|
||||
# informative error message.
|
||||
continue-on-error: true
|
||||
run: >-
|
||||
run: |
|
||||
mkdir -p $(pwd)/dist &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
|
||||
|
||||
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
|
||||
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/
|
||||
fi
|
||||
- name: Skip the test run if the wheel artifacts were not downloaded successfully
|
||||
if: steps.download-wheel-artifacts.outcome == 'failure'
|
||||
run: |
|
||||
|
2
.github/workflows/wheel_tests_continuous.yml
vendored
2
.github/workflows/wheel_tests_continuous.yml
vendored
@ -79,6 +79,7 @@ jobs:
|
||||
runner: ${{ matrix.runner }}
|
||||
python: ${{ matrix.python }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
install-jax-current-commit: 1
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
||||
|
||||
run-pytest-cuda:
|
||||
@ -109,6 +110,7 @@ jobs:
|
||||
python: ${{ matrix.python }}
|
||||
cuda: ${{ matrix.cuda }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
install-jax-current-commit: 1
|
||||
# GCS upload URI is the same for both artifact build jobs
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
||||
|
||||
|
@ -22,11 +22,6 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
|
||||
# GCS bucket.
|
||||
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "0"
|
||||
|
||||
jobs:
|
||||
run-pytest-cpu:
|
||||
uses: ./.github/workflows/pytest_cpu.yml
|
||||
@ -42,6 +37,9 @@ jobs:
|
||||
runner: ${{ matrix.runner }}
|
||||
python: ${{ matrix.python }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
|
||||
# GCS bucket.
|
||||
install-jax-current-commit: 0
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
||||
|
||||
run-pytest-cuda:
|
||||
@ -60,4 +58,7 @@ jobs:
|
||||
python: ${{ matrix.python }}
|
||||
cuda: ${{ matrix.cuda }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
|
||||
# GCS bucket.
|
||||
install-jax-current-commit: 0
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
Loading…
x
Reference in New Issue
Block a user