Clarify the NVidia driver version requirements.

This commit is contained in:
Peter Hawkins 2022-02-17 14:31:16 -05:00
parent 032bfe0915
commit d704c151fa

View File

@ -425,24 +425,25 @@ with CUDA 11.1 or newer, and CuDNN 8.0.5 or newer. Other combinations of
operating system, CUDA, and CuDNN are possible, but require building from
source.
* CUDA 11.1 or newer is required.
* CUDA 11.1 or newer is *required*.
* You may be able to use older CUDA versions if you build from source, but
there are known bugs in CUDA in all CUDA versions older than 11.1.
there are known bugs in CUDA in all CUDA versions older than 11.1, so we
do not ship prebuilt binaries for older CUDA versions.
* The supported cuDNN versions for the prebuilt wheels are:
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.0.5 or newer.
* To use the prebuilt wheels, we recommend you have NVidia driver 470.82.01 or
newer installed.
* If your GPU has a
[CUDA compute capability](https://developer.nvidia.com/cuda-gpus) that
matches one of the architectures that the JAX wheels are built for
(3.5, 5.2, 6.0, 7.0, and 8.0), you may also be able to use older drivers
that are version 450.80.02 or newer.
* If your GPU does not match one of those architectures, you must either
use driver 470.82.01 or newer, use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/),
or build jaxlib from source.
* You *must* use an NVidia driver version that is at least as new as your
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
For example, if you have CUDA 11.4 update 4 installed, you must use NVidia
driver 470.82.01 or newer if on Linux. This is a strict requirement that
exists because JAX relies on JIT-compiling code; older drivers may lead to
failures.
* If you need to use an newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVidia driver easily, you may be
able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVidia provides for this purpose.
Next, run
@ -461,7 +462,7 @@ version for jaxlib explicitly:
```bash
pip install --upgrade pip
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.2
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5