From 0ed42dcdd0c2758df402ab8146233db9ab0c3cf9 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Feb 2025 11:45:32 -0800 Subject: [PATCH] Print the list of installed packages before running pytests Also, do not upgrade packages and disable editable mode when installing JAX at head PiperOrigin-RevId: 732208266 --- ci/run_pytest_cpu.sh | 4 ++++ ci/run_pytest_cuda.sh | 4 ++++ ci/run_pytest_tpu.sh | 4 ++++ ci/utilities/install_wheels_locally.sh | 4 ++-- 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 0b045bdc7..43581ef2c 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -33,6 +33,10 @@ source ./ci/utilities/install_wheels_locally.sh # Set up the build environment. source "ci/utilities/setup_build_environment.sh" +# Print all the installed packages +echo "Installed packages:" +"$JAXCI_PYTHON" -m uv pip list + "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" # Set up all test environment variables diff --git a/ci/run_pytest_cuda.sh b/ci/run_pytest_cuda.sh index d98068385..45020542b 100755 --- a/ci/run_pytest_cuda.sh +++ b/ci/run_pytest_cuda.sh @@ -34,6 +34,10 @@ source ./ci/utilities/install_wheels_locally.sh # Set up the build environment. source "ci/utilities/setup_build_environment.sh" +# Print all the installed packages +echo "Installed packages:" +"$JAXCI_PYTHON" -m uv pip list + "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" nvidia-smi diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 3707efe88..feaccea8e 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -33,6 +33,10 @@ source ./ci/utilities/install_wheels_locally.sh # Set up the build environment. source "ci/utilities/setup_build_environment.sh" +# Print all the installed packages +echo "Installed packages:" +"$JAXCI_PYTHON" -m uv pip list + "$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" "$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' "$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 8a49d8f06..f0e245e14 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -41,7 +41,7 @@ else fi if [[ "$JAXCI_INSTALL_JAX_CURRENT_COMMIT" == "1" ]]; then - echo "Installing the JAX package in editable mode at the current commit..." + echo "Installing the JAX package at the current commit..." # Install JAX package at the current commit. - "$JAXCI_PYTHON" -m uv pip install -U -e . + "$JAXCI_PYTHON" -m uv pip install . fi \ No newline at end of file