diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml new file mode 100644 index 000000000..762f647b0 --- /dev/null +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -0,0 +1,67 @@ +# CI - Wheel Tests (Continuous) +# +# This workflow builds JAX artifacts and runs CPU/CUDA tests. +# +# It orchestrates the following: +# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and +# uploads it to a GCS bucket. +# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow to download the jaxlib wheel that was built +# in the previous step and runs CPU tests. +# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and +# uploads them to a GCS bucket. +# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts +# that were built in the previous steps and runs the CUDA tests. +name: CI - Wheel Tests (Nightly/Release) + +on: + workflow_dispatch: + inputs: + gcs_download_uri: + description: "GCS location URI from where the artifacts should be downloaded" + required: true + default: 'gs://jax-nightly-release-transient/nightly/latest' + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +env: + # Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the + # GCS bucket. + JAXCI_INSTALL_JAX_CURRENT_COMMIT: "0" + +jobs: + run-pytest-cpu: + uses: ./.github/workflows/pytest_cpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy of our internal CI jobs + # that build the wheels. + runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] + python: ["3.10","3.11", "3.12", "3.13"] + enable-x64: [0] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{inputs.gcs_download_uri}} + + run-pytest-cuda: + uses: ./.github/workflows/pytest_cuda.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy of our internal CI jobs + # that build the wheels. + runner: ["linux-x86-g2-48-l4-4gpu"] + python: ["3.10","3.11", "3.12", "3.13"] + cuda: ["12.3", "12.1"] + enable-x64: [0] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + cuda: ${{ matrix.cuda }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file diff --git a/ci/envs/default.env b/ci/envs/default.env index ae434dc61..f74a29688 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -67,3 +67,8 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} # on the system. By default, it is set to match the version of the hermetic # Python used by Bazel for building the wheels. export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} + +# Installs the JAX package in editable mode at the current commit. Enabled by +# default. Nightly/Release builds disable this flag in the Github action +# workflow files. +export JAXCI_INSTALL_JAX_CURRENT_COMMIT=${JAXCI_INSTALL_JAX_CURRENT_COMMIT:-"1"} \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 4af679c9d..fbeafe22d 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -17,7 +17,7 @@ # Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python # binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to # avoid using the Windows version of `find` on Msys. -WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) +WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) if [[ -z "$WHEELS" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" @@ -34,6 +34,8 @@ else "$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" fi -echo "Installing the JAX package in editable mode at the current commit..." -# Install JAX package at the current commit. -"$JAXCI_PYTHON" -m pip install -U -e . +if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then + echo "Installing the JAX package in editable mode at the current commit..." + # Install JAX package at the current commit. + "$JAXCI_PYTHON" -m pip install -U -e . +fi \ No newline at end of file