Update the cuda 12 dependencies since we upgraded to cuda 12.3

PiperOrigin-RevId: 607453817
This commit is contained in:
Yash Katariya 2024-02-15 14:20:20 -08:00 committed by jax authors
parent 4834423e17
commit f12550964d

View File

@ -127,16 +127,15 @@ setup(
'cuda12_pip': [
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
"nvidia-cublas-cu12>=12.2.5.6",
"nvidia-cuda-cupti-cu12>=12.2.142",
"nvidia-cuda-nvcc-cu12>=12.2.140",
"nvidia-cuda-runtime-cu12>=12.2.140",
"nvidia-cudnn-cu12>=8.9",
"nvidia-cufft-cu12>=11.0.8.103",
"nvidia-cusolver-cu12>=11.5.2",
"nvidia-cusparse-cu12>=12.1.2.141",
"nvidia-cublas-cu12>=12.3.4.1",
"nvidia-cuda-cupti-cu12>=12.3.101",
"nvidia-cuda-nvcc-cu12>=12.3.107",
"nvidia-cuda-runtime-cu12>=12.3.101",
"nvidia-cudnn-cu12>=8.9.7.29",
"nvidia-cufft-cu12>=11.0.12.1",
"nvidia-cusolver-cu12>=11.5.4.101",
"nvidia-cusparse-cu12>=12.2.0.103",
"nvidia-nccl-cu12>=2.19.3",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
@ -144,22 +143,21 @@ setup(
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add an version constraint
# here.
"nvidia-nvjitlink-cu12>=12.2",
"nvidia-nvjitlink-cu12>=12.3.101",
],
'cuda12': [
f"jaxlib=={_current_jaxlib_version}",
f"jax-cuda12-plugin=={_current_jaxlib_version}",
"nvidia-cublas-cu12>=12.2.5.6",
"nvidia-cuda-cupti-cu12>=12.2.142",
"nvidia-cuda-nvcc-cu12>=12.2.140",
"nvidia-cuda-runtime-cu12>=12.2.140",
"nvidia-cudnn-cu12>=8.9",
"nvidia-cufft-cu12>=11.0.8.103",
"nvidia-cusolver-cu12>=11.5.2",
"nvidia-cusparse-cu12>=12.1.2.141",
"nvidia-cublas-cu12>=12.3.4.1",
"nvidia-cuda-cupti-cu12>=12.3.101",
"nvidia-cuda-nvcc-cu12>=12.3.107",
"nvidia-cuda-runtime-cu12>=12.3.101",
"nvidia-cudnn-cu12>=8.9.7.29",
"nvidia-cufft-cu12>=11.0.12.1",
"nvidia-cusolver-cu12>=11.5.4.101",
"nvidia-cusparse-cu12>=12.2.0.103",
"nvidia-nccl-cu12>=2.19.3",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
@ -167,7 +165,7 @@ setup(
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add an version constraint
# here.
"nvidia-nvjitlink-cu12>=12.2",
"nvidia-nvjitlink-cu12>=12.3.101",
],
# Target that does not depend on the CUDA pip wheels, for those who want