Use uv to install Python packages

PiperOrigin-RevId: 730499307
This commit is contained in:
Nitin Srinivasan 2025-02-24 10:13:00 -08:00 committed by jax authors
parent 62530d5922
commit 4b4f2f9cb9
13 changed files with 58 additions and 56 deletions

View File

@ -62,7 +62,8 @@ jobs:
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
pip install -r build/test-requirements.txt
pip install uv~=0.5.30
uv pip install -r build/test-requirements.txt
- name: Build and install JAX
env:
ASAN_OPTIONS: detect_leaks=0
@ -73,8 +74,8 @@ jobs:
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=address \
--clang_path=/usr/bin/clang-18
pip install dist/jaxlib-*.whl
pip install -e .
uv pip install dist/jaxlib-*.whl \
-e .
- name: Run tests
env:
ASAN_OPTIONS: detect_leaks=0

View File

@ -75,7 +75,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
- name: Run tests
@ -113,7 +113,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system -r docs/requirements.txt
- name: Test documentation
env:
@ -147,7 +147,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system -r docs/requirements.txt
- name: Render documentation
run: |
@ -173,7 +173,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system .[minimum-jaxlib] tensorflow -r build/test-requirements.txt
- name: Run tests
@ -205,7 +205,7 @@ jobs:
python-version: 3.12
- name: Install JAX
run: |
pip install uv
pip install uv~=0.5.30
uv pip install --system .[cuda12]
- name: Build and install example project
run: uv pip install --system ./examples/ffi[test]

View File

@ -59,11 +59,10 @@ jobs:
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$PYTHON -m pip install -U -r build/test-requirements.txt
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
$PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
$PYTHON -m pip uninstall -y jax jaxlib libtpu
$PYTHON -m uv pip uninstall -y jax jaxlib libtpu
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
# Build and install jaxlib at head
$PYTHON build/build.py build --wheels=jaxlib \
@ -71,30 +70,27 @@ jobs:
--local_xla_path="$(pwd)/xla" \
--verbose
$PYTHON -m pip install dist/*.whl
# Install "jax" at head
$PYTHON -m pip install -U -e .
# Install libtpu
$PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Install jaxlib, "jax" at head, and libtpu
$PYTHON -m uv pip install dist/*.whl \
-U -e . \
--pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
$PYTHON -m pip install .[tpu] \
$PYTHON -m uv pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
$PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$PYTHON -m pip install requests
$PYTHON -m uv pip install \
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
install requests
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
$PYTHON -m pip install requests
$PYTHON -m uv pip install \
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
install requests
else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1

View File

@ -74,8 +74,7 @@ jobs:
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
$JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
- name: Build jaxlib at head with latest XLA
run: |
# Build and install jaxlib at head
@ -86,7 +85,7 @@ jobs:
--verbose
# Install libtpu
$JAXCI_PYTHON -m pip install --pre libtpu \
$JAXCI_PYTHON -m uv pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Halt for testing
- name: Wait For Connection

View File

@ -37,9 +37,8 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
uv pip install --system .[ci]
uv pip install --system pytest-xdist -r array-api-tests/requirements.txt
pip install uv~=0.5.30
uv pip install --system .[ci] --system pytest-xdist -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.numpy

View File

@ -35,15 +35,14 @@ jobs:
rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
pip install -U pip numpy wheel
pip install absl-py pytest
pip install uv~=0.5.30
uv pip install -U pip numpy wheel absl-py pytest
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
pip install --pre jaxlib \
uv pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
fi;
cd jax
pip install .
pip install jax-metal
uv pip install . jax-metal
- name: Run test
run: |
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate

View File

@ -125,7 +125,7 @@ jobs:
echo "Skipping the test run."
exit 1
- name: Install Python dependencies
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main

View File

@ -106,7 +106,7 @@ jobs:
echo "Skipping the test run."
exit 1
- name: Install Python dependencies
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
run: $JAXCI_PYTHON -m uv pip install -r build/requirements.in
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main

View File

@ -111,9 +111,9 @@ jobs:
export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH
python3 -m pip install -r requirements/build_requirements.txt
python3 -m pip install uv~=0.5.30
# Make sure to install a compatible Cython version (master branch is best for now)
python3 -m pip install -U git+https://github.com/cython/cython
python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython
CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized

View File

@ -38,10 +38,11 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install JAX test requirements
run: |
pip install .[ci] -r build/test-requirements.txt
pip install uv~=0.5.30
uv pip install .[ci] -r build/test-requirements.txt
- name: Install numpy & scipy development versions
run: |
pip install \
uv pip install \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
--no-deps \
--pre \

View File

@ -37,8 +37,9 @@ jobs:
BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC"
JAXLIB_RELEASE: true
run: |
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
python -m pip install uv~=0.5.30
python -m uv pip install -r build/test-requirements.txt \
--upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
@ -57,7 +58,7 @@ jobs:
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
python -m pip install --find-links ${{ github.workspace }}\dist jaxlib
python -m pip install -e ${{ github.workspace }}
python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \
-e ${{ github.workspace }}
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples

View File

@ -46,8 +46,8 @@ jobs:
JAXLIB_NIGHTLY: true # Tag the wheels as dev versions
run: |
cd jax
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
python -m pip install uv~=0.5.30
python -m uv pip install -r build/test-requirements.txt --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py build --wheels=jaxlib `
--bazel_options=--color=yes `
@ -67,7 +67,7 @@ jobs:
PY_COLORS: 1
run: |
cd jax
python -m pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib
python -m pip install -e ${{ github.workspace }}\jax
python -m uv pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib \
-e ${{ github.workspace }}\jax
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples

View File

@ -27,15 +27,21 @@ fi
echo "Installing the following wheels:"
echo "${WHEELS[@]}"
# Install `uv` if it's not already installed. `uv` is much faster than pip for
# installing Python packages.
if ! command -v uv >/dev/null 2>&1; then
pip install uv~=0.5.30
fi
# On Windows, convert MSYS Linux-like paths to Windows paths.
if [[ $(uname -s) =~ "MSYS_NT" ]]; then
"$JAXCI_PYTHON" -m pip install $(cygpath -w "${WHEELS[@]}")
"$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}")
else
"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}"
"$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}"
fi
if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then
echo "Installing the JAX package in editable mode at the current commit..."
# Install JAX package at the current commit.
"$JAXCI_PYTHON" -m pip install -U -e .
"$JAXCI_PYTHON" -m uv pip install -U -e .
fi