README: mention that CUDA installation is for linux only

This commit is contained in:
Jake VanderPlas 2021-08-28 08:15:21 -07:00
parent ae9fad62bf
commit e46200e123

View File

@ -420,7 +420,8 @@ install [CUDA](https://developer.nvidia.com/cuda-downloads) and
[CuDNN](https://developer.nvidia.com/CUDNN),
if they have not already been installed. Unlike some other popular deep
learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip`
package. The CUDA 10 JAX wheels require CuDNN 7, whereas the CUDA 11 wheels of
package. JAX provides pre-built CUDA-compatible wheels for **linux only**;
the CUDA 10 JAX wheels require CuDNN 7, whereas the CUDA 11 wheels of
JAX require CuDNN 8. Other combinations of CUDA and CuDNN are possible but
require building from source.
@ -428,7 +429,7 @@ Next, run
```bash
pip install --upgrade pip
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux.
```
The jaxlib version must correspond to the version of the existing CUDA