mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Update the cuda 12 dependencies since we upgraded to cuda 12.3
PiperOrigin-RevId: 607453817
This commit is contained in:
parent
4834423e17
commit
f12550964d
38
setup.py
38
setup.py
@ -127,16 +127,15 @@ setup(
|
||||
|
||||
'cuda12_pip': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
|
||||
"nvidia-cublas-cu12>=12.2.5.6",
|
||||
"nvidia-cuda-cupti-cu12>=12.2.142",
|
||||
"nvidia-cuda-nvcc-cu12>=12.2.140",
|
||||
"nvidia-cuda-runtime-cu12>=12.2.140",
|
||||
"nvidia-cudnn-cu12>=8.9",
|
||||
"nvidia-cufft-cu12>=11.0.8.103",
|
||||
"nvidia-cusolver-cu12>=11.5.2",
|
||||
"nvidia-cusparse-cu12>=12.1.2.141",
|
||||
"nvidia-cublas-cu12>=12.3.4.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.3.101",
|
||||
"nvidia-cuda-nvcc-cu12>=12.3.107",
|
||||
"nvidia-cuda-runtime-cu12>=12.3.101",
|
||||
"nvidia-cudnn-cu12>=8.9.7.29",
|
||||
"nvidia-cufft-cu12>=11.0.12.1",
|
||||
"nvidia-cusolver-cu12>=11.5.4.101",
|
||||
"nvidia-cusparse-cu12>=12.2.0.103",
|
||||
"nvidia-nccl-cu12>=2.19.3",
|
||||
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
@ -144,22 +143,21 @@ setup(
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add an version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.2",
|
||||
"nvidia-nvjitlink-cu12>=12.3.101",
|
||||
],
|
||||
|
||||
'cuda12': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jax-cuda12-plugin=={_current_jaxlib_version}",
|
||||
"nvidia-cublas-cu12>=12.2.5.6",
|
||||
"nvidia-cuda-cupti-cu12>=12.2.142",
|
||||
"nvidia-cuda-nvcc-cu12>=12.2.140",
|
||||
"nvidia-cuda-runtime-cu12>=12.2.140",
|
||||
"nvidia-cudnn-cu12>=8.9",
|
||||
"nvidia-cufft-cu12>=11.0.8.103",
|
||||
"nvidia-cusolver-cu12>=11.5.2",
|
||||
"nvidia-cusparse-cu12>=12.1.2.141",
|
||||
"nvidia-cublas-cu12>=12.3.4.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.3.101",
|
||||
"nvidia-cuda-nvcc-cu12>=12.3.107",
|
||||
"nvidia-cuda-runtime-cu12>=12.3.101",
|
||||
"nvidia-cudnn-cu12>=8.9.7.29",
|
||||
"nvidia-cufft-cu12>=11.0.12.1",
|
||||
"nvidia-cusolver-cu12>=11.5.4.101",
|
||||
"nvidia-cusparse-cu12>=12.2.0.103",
|
||||
"nvidia-nccl-cu12>=2.19.3",
|
||||
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
@ -167,7 +165,7 @@ setup(
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add an version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.2",
|
||||
"nvidia-nvjitlink-cu12>=12.3.101",
|
||||
],
|
||||
|
||||
# Target that does not depend on the CUDA pip wheels, for those who want
|
||||
|
Loading…
x
Reference in New Issue
Block a user