mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Use cuda11_cudnn82
instead of cuda=11,cudnn=82
because the latter one is a syntax error
PiperOrigin-RevId: 404240654
This commit is contained in:
parent
a91cb81613
commit
ee752b32f7
@ -35,13 +35,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
pip install --upgrade pip
|
||||
|
||||
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
|
||||
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
|
||||
# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
|
||||
pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
|
||||
# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
|
||||
pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
```
|
||||
|
||||
## jax 0.2.22 (Oct 12, 2021)
|
||||
|
@ -449,10 +449,10 @@ You can specify a particular CUDA and cuDNN version for jaxlib explicitly:
|
||||
pip install --upgrade pip
|
||||
|
||||
# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
|
||||
pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
|
||||
# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
|
||||
pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
```
|
||||
|
||||
You can find your CUDA version with the command:
|
||||
|
6
setup.py
6
setup.py
@ -66,9 +66,9 @@ setup(
|
||||
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda{_default_cuda_version}.cudnn{_default_cudnn_version}"],
|
||||
|
||||
# CUDA installations require adding jax releases URL; e.g.
|
||||
# $ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
# $ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
**{f'cuda={cuda_version},cudnn={cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}"
|
||||
# $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
# $ pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
**{f'cuda{cuda_version}_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}"
|
||||
for cuda_version in _available_cuda_versions for cudnn_version in _available_cudnn_versions}
|
||||
},
|
||||
url='https://github.com/google/jax',
|
||||
|
Loading…
x
Reference in New Issue
Block a user