Add cupti pip dependency, needed for GPU profiling.

Issue https://github.com/google/jax/issues/15384

PiperOrigin-RevId: 521841461
This commit is contained in:
Peter Hawkins 2023-04-04 12:54:55 -07:00 committed by jax authors
parent c1f65fc8b2
commit 75d0f6522d

View File

@ -97,6 +97,7 @@ setup(
'cuda11_pip': [
f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}",
"nvidia-cublas-cu11>=11.11",
"nvidia-cuda-cupti-cu11>=11.8",
"nvidia-cuda-nvcc-cu11>=11.8",
"nvidia-cuda-runtime-cu11>=11.8",
"nvidia-cudnn-cu11>=8.6",
@ -108,6 +109,7 @@ setup(
'cuda12_pip': [
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
"nvidia-cublas-cu12",
"nvidia-cuda-cupti-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-cuda-runtime-cu12",
"nvidia-cudnn-cu12",