mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Build JAX wheels instead of installing it from the source repository
This change allows us to get rid of extra env vars which used to control whether to install `jax` at head. Now, `jax` will be be built and consumed in the same way as the other wheels in the continuous jobs. PiperOrigin-RevId: 734123590
This commit is contained in:
parent
2a34019388
commit
623865fe95
24
.github/workflows/pytest_cpu.yml
vendored
24
.github/workflows/pytest_cpu.yml
vendored
@ -29,11 +29,6 @@ on:
|
||||
type: string
|
||||
required: true
|
||||
default: "0"
|
||||
install-jax-current-commit:
|
||||
description: "Should the 'jax' package be installed from the current commit?"
|
||||
type: string
|
||||
required: true
|
||||
default: "1"
|
||||
gcs_download_uri:
|
||||
description: "GCS location prefix from where the artifacts should be downloaded"
|
||||
required: true
|
||||
@ -62,7 +57,6 @@ jobs:
|
||||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
|
||||
JAXCI_PYTHON: "python${{ inputs.python }}"
|
||||
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
|
||||
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -88,7 +82,7 @@ jobs:
|
||||
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
|
||||
# `*-cp<py_version>-cp<py_version>t-*`.
|
||||
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
|
||||
- name: Download jaxlib wheel from GCS (non-Windows runs)
|
||||
- name: Download wheels from GCS (non-Windows runs)
|
||||
id: download-wheel-artifacts-nw
|
||||
# Set continue-on-error to true to prevent actions from failing the workflow if this step
|
||||
# fails. Instead, we verify the outcome in the step below so that we can print a more
|
||||
@ -96,14 +90,10 @@ jobs:
|
||||
continue-on-error: true
|
||||
if: ${{ !contains(inputs.runner, 'windows-x86') }}
|
||||
run: |
|
||||
mkdir -p $(pwd)/dist &&
|
||||
mkdir -p $(pwd)/dist
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
|
||||
|
||||
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
|
||||
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
|
||||
fi
|
||||
- name: Download jaxlib wheel from GCS (Windows runs)
|
||||
- name: Download wheels from GCS (Windows runs)
|
||||
id: download-wheel-artifacts-w
|
||||
# Set continue-on-error to true to prevent actions from failing the workflow if this step
|
||||
# fails. Instead, we verify the outcome in step below so that we can print a more
|
||||
@ -115,12 +105,8 @@ jobs:
|
||||
mkdir dist
|
||||
@REM Use `call` so that we can run sequential gsutil commands on Windows
|
||||
@REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
|
||||
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
|
||||
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
|
||||
|
||||
@REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
|
||||
if not "${{ inputs.install-jax-current-commit }}"=="1" (
|
||||
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
|
||||
)
|
||||
- name: Skip the test run if the wheel artifacts were not downloaded successfully
|
||||
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
|
||||
run: |
|
||||
|
14
.github/workflows/pytest_cuda.yml
vendored
14
.github/workflows/pytest_cuda.yml
vendored
@ -34,11 +34,6 @@ on:
|
||||
type: string
|
||||
required: true
|
||||
default: "0"
|
||||
install-jax-current-commit:
|
||||
description: "Should the 'jax' package be installed from the current commit?"
|
||||
type: string
|
||||
required: true
|
||||
default: "1"
|
||||
gcs_download_uri:
|
||||
description: "GCS location prefix from where the artifacts should be downloaded"
|
||||
required: true
|
||||
@ -66,7 +61,6 @@ jobs:
|
||||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
|
||||
JAXCI_PYTHON: "python${{ inputs.python }}"
|
||||
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
|
||||
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -86,7 +80,7 @@ jobs:
|
||||
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
|
||||
# `*-cp<py_version>-cp<py_version>t-*`.
|
||||
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
|
||||
- name: Download the wheel artifacts from GCS
|
||||
- name: Download wheels from GCS
|
||||
id: download-wheel-artifacts
|
||||
# Set continue-on-error to true to prevent actions from failing the workflow if this step
|
||||
# fails. Instead, we verify the outcome in the next step so that we can print a more
|
||||
@ -94,14 +88,10 @@ jobs:
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mkdir -p $(pwd)/dist &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
|
||||
|
||||
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
|
||||
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
|
||||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
|
||||
fi
|
||||
- name: Skip the test run if the wheel artifacts were not downloaded successfully
|
||||
if: steps.download-wheel-artifacts.outcome == 'failure'
|
||||
run: |
|
||||
|
16
.github/workflows/wheel_tests_continuous.yml
vendored
16
.github/workflows/wheel_tests_continuous.yml
vendored
@ -27,6 +27,16 @@ concurrency:
|
||||
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
|
||||
|
||||
jobs:
|
||||
build-jax-artifact:
|
||||
uses: ./.github/workflows/build_artifacts.yml
|
||||
with:
|
||||
# Note that since jax is a pure python package, the runner OS and Python values do not
|
||||
# matter. In addition, cloning main XLA also has no effect.
|
||||
runner: "linux-x86-n2-16"
|
||||
artifact: "jax"
|
||||
upload_artifacts_to_gcs: true
|
||||
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
|
||||
|
||||
build-jaxlib-artifact:
|
||||
uses: ./.github/workflows/build_artifacts.yml
|
||||
strategy:
|
||||
@ -66,7 +76,7 @@ jobs:
|
||||
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
|
||||
# still want to run the tests for other platforms.
|
||||
if: ${{ !cancelled() }}
|
||||
needs: build-jaxlib-artifact
|
||||
needs: [build-jax-artifact, build-jaxlib-artifact]
|
||||
uses: ./.github/workflows/pytest_cpu.yml
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
@ -80,7 +90,6 @@ jobs:
|
||||
runner: ${{ matrix.runner }}
|
||||
python: ${{ matrix.python }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
install-jax-current-commit: 1
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
||||
|
||||
run-pytest-cuda:
|
||||
@ -88,7 +97,7 @@ jobs:
|
||||
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
|
||||
# still want to run the tests for other platforms.
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [build-jaxlib-artifact, build-cuda-artifacts]
|
||||
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
|
||||
uses: ./.github/workflows/pytest_cuda.yml
|
||||
strategy:
|
||||
fail-fast: false # don't cancel all jobs on failure
|
||||
@ -111,7 +120,6 @@ jobs:
|
||||
python: ${{ matrix.python }}
|
||||
cuda: ${{ matrix.cuda }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
install-jax-current-commit: 1
|
||||
# GCS upload URI is the same for both artifact build jobs
|
||||
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
|
||||
|
||||
|
@ -40,9 +40,6 @@ jobs:
|
||||
runner: ${{ matrix.runner }}
|
||||
python: ${{ matrix.python }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
|
||||
# GCS bucket.
|
||||
install-jax-current-commit: 0
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
||||
|
||||
run-pytest-cuda:
|
||||
@ -61,7 +58,4 @@ jobs:
|
||||
python: ${{ matrix.python }}
|
||||
cuda: ${{ matrix.cuda }}
|
||||
enable-x64: ${{ matrix.enable-x64 }}
|
||||
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
|
||||
# GCS bucket.
|
||||
install-jax-current-commit: 0
|
||||
gcs_download_uri: ${{inputs.gcs_download_uri}}
|
@ -74,9 +74,4 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
|
||||
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
|
||||
# on the system. By default, it is set to match the version of the hermetic
|
||||
# Python used by Bazel for building the wheels.
|
||||
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
|
||||
|
||||
# Installs the JAX package in editable mode at the current commit. Enabled by
|
||||
# default. Nightly/Release builds disable this flag in the Github action
|
||||
# workflow files.
|
||||
export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"}
|
||||
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
|
@ -19,7 +19,7 @@
|
||||
# avoid using the Windows version of `find` on Msys.
|
||||
WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) )
|
||||
|
||||
if [[ -z "$WHEELS" ]]; then
|
||||
if [[ -z "${WHEELS[@]}" ]]; then
|
||||
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
|
||||
exit 1
|
||||
fi
|
||||
@ -38,10 +38,4 @@ if [[ $(uname -s) =~ "MSYS_NT" ]]; then
|
||||
"$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}")
|
||||
else
|
||||
"$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}"
|
||||
fi
|
||||
|
||||
if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then
|
||||
echo "Installing the JAX package at the current commit..."
|
||||
# Install JAX package at the current commit.
|
||||
"$JAXCI_PYTHON" -m uv pip install .
|
||||
fi
|
Loading…
x
Reference in New Issue
Block a user