diff --git a/CHANGELOG.md b/CHANGELOG.md index 54575d1c4..47cea93ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,13 +35,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. - pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html + pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. - pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html + pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. - pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html + pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html ``` ## jax 0.2.22 (Oct 12, 2021) diff --git a/README.md b/README.md index e646cc844..19ecc693a 100644 --- a/README.md +++ b/README.md @@ -449,10 +449,10 @@ You can specify a particular CUDA and cuDNN version for jaxlib explicitly: pip install --upgrade pip # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. -pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. -pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html ``` You can find your CUDA version with the command: diff --git a/setup.py b/setup.py index 29f4495d3..5922c012e 100644 --- a/setup.py +++ b/setup.py @@ -66,9 +66,9 @@ setup( 'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda{_default_cuda_version}.cudnn{_default_cudnn_version}"], # CUDA installations require adding jax releases URL; e.g. - # $ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html - # $ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html - **{f'cuda={cuda_version},cudnn={cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}" + # $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html + # $ pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html + **{f'cuda{cuda_version}_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}" for cuda_version in _available_cuda_versions for cudnn_version in _available_cudnn_versions} }, url='https://github.com/google/jax',