Use cuda11_cudnn82 instead of cuda=11,cudnn=82 because the latter one is a syntax error

PiperOrigin-RevId: 404240654
This commit is contained in:
Yash Katariya 2021-10-19 06:24:27 -07:00 committed by jax authors
parent a91cb81613
commit ee752b32f7
3 changed files with 8 additions and 8 deletions

View File

@ -35,13 +35,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
## jax 0.2.22 (Oct 12, 2021)

View File

@ -449,10 +449,10 @@ You can specify a particular CUDA and cuDNN version for jaxlib explicitly:
pip install --upgrade pip
# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
You can find your CUDA version with the command:

View File

@ -66,9 +66,9 @@ setup(
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda{_default_cuda_version}.cudnn{_default_cudnn_version}"],
# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# $ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
**{f'cuda={cuda_version},cudnn={cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}"
# $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# $ pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
**{f'cuda{cuda_version}_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}"
for cuda_version in _available_cuda_versions for cudnn_version in _available_cudnn_versions}
},
url='https://github.com/google/jax',