Print the list of installed packages before running pytests

Also, do not upgrade packages and disable editable mode when installing JAX at head

PiperOrigin-RevId: 732208266
This commit is contained in:
Nitin Srinivasan 2025-02-28 11:45:32 -08:00 committed by jax authors
parent da1cc0a50e
commit 0ed42dcdd0
4 changed files with 14 additions and 2 deletions

View File

@ -33,6 +33,10 @@ source ./ci/utilities/install_wheels_locally.sh
# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"
# Print all the installed packages
echo "Installed packages:"
"$JAXCI_PYTHON" -m uv pip list
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
# Set up all test environment variables

View File

@ -34,6 +34,10 @@ source ./ci/utilities/install_wheels_locally.sh
# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"
# Print all the installed packages
echo "Installed packages:"
"$JAXCI_PYTHON" -m uv pip list
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
nvidia-smi

View File

@ -33,6 +33,10 @@ source ./ci/utilities/install_wheels_locally.sh
# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"
# Print all the installed packages
echo "Installed packages:"
"$JAXCI_PYTHON" -m uv pip list
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'

View File

@ -41,7 +41,7 @@ else
fi
if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then
echo "Installing the JAX package in editable mode at the current commit..."
echo "Installing the JAX package at the current commit..."
# Install JAX package at the current commit.
"$JAXCI_PYTHON" -m uv pip install -U -e .
"$JAXCI_PYTHON" -m uv pip install .
fi