mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15278 from hawkinsp:cudainstall
PiperOrigin-RevId: 520364354
This commit is contained in:
commit
2d94f76ca3
112
README.md
112
README.md
@ -417,66 +417,82 @@ On Linux, it is often necessary to first update `pip` to a version that supports
|
||||
**These `pip` installations do not work with Windows, and may fail silently; see
|
||||
[above](#installation).**
|
||||
|
||||
### pip installation: GPU (CUDA)
|
||||
### pip installation: GPU (CUDA, installed via pip, easier)
|
||||
|
||||
If you want to install JAX with both CPU and NVidia GPU support, you must first
|
||||
install [CUDA](https://developer.nvidia.com/cuda-downloads) and
|
||||
[CuDNN](https://developer.nvidia.com/CUDNN),
|
||||
if they have not already been installed. Unlike some other popular deep
|
||||
learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip`
|
||||
package.
|
||||
There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
|
||||
installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
|
||||
installing CUDA and CUDNN using the pip wheels, since it is much easier!
|
||||
|
||||
JAX provides pre-built CUDA-compatible wheels for **Linux only**,
|
||||
with CUDA 11.4 or newer, and CuDNN 8.2 or newer. Note these existing wheels are currently for `x86_64` architectures only. Other combinations of
|
||||
operating system, CUDA, and CuDNN are possible, but require [building from
|
||||
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
|
||||
|
||||
* CUDA 11.4 or newer is *required*.
|
||||
* Your CUDA installation must be new enough to support your GPU. If you have
|
||||
an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
|
||||
you must use CUDA 11.8 or newer.
|
||||
* The supported cuDNN versions for the prebuilt wheels are:
|
||||
* cuDNN 8.6 or newer. We recommend using the cuDNN 8.6 wheel if your cuDNN
|
||||
installation is new enough, since it supports additional functionality.
|
||||
* cuDNN 8.2 or newer.
|
||||
* You *must* use an NVIDIA driver version that is at least as new as your
|
||||
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
|
||||
For example, if you have CUDA 11.4 update 4 installed, you must use NVIDIA
|
||||
driver 470.82.01 or newer if on Linux. This is a strict requirement that
|
||||
exists because JAX relies on JIT-compiling code; older drivers may lead to
|
||||
failures.
|
||||
* If you need to use an newer CUDA toolkit with an older driver, for example
|
||||
on a cluster where you cannot update the NVIDIA driver easily, you may be
|
||||
able to use the
|
||||
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
|
||||
that NVIDIA provides for this purpose.
|
||||
|
||||
Next, run
|
||||
You must first install the NVIDIA driver. We
|
||||
recommend installing the newest driver available from NVIDIA, but the driver
|
||||
must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
|
||||
# CUDA 12 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# CUDA 11 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
### pip installation: GPU (CUDA, installed locally, harder)
|
||||
|
||||
If you prefer to use a preinstalled copy of CUDA, you must first
|
||||
install [CUDA](https://developer.nvidia.com/cuda-downloads) and
|
||||
[CuDNN](https://developer.nvidia.com/CUDNN).
|
||||
|
||||
JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
|
||||
combinations of operating system and architecture are possible, but require
|
||||
[building from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
|
||||
|
||||
You should use an NVIDIA driver version that is at least as new as your
|
||||
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
|
||||
If you need to use an newer CUDA toolkit with an older driver, for example
|
||||
on a cluster where you cannot update the NVIDIA driver easily, you may be
|
||||
able to use the
|
||||
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
|
||||
that NVIDIA provides for this purpose.
|
||||
|
||||
JAX currently ships three CUDA wheel variants:
|
||||
* CUDA 12.0 and CuDNN 8.8.
|
||||
* CUDA 11.8 and CuDNN 8.6.
|
||||
* CUDA 11.4 and CuDNN 8.2. This wheel is deprecated and will be discontinued
|
||||
with jax 0.4.8.
|
||||
|
||||
You may use a JAX wheel provided the major version of your CUDA and CuDNN
|
||||
installation matches, and the minor version is at least as new as the version
|
||||
JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
|
||||
CUDA 12.1 and CuDNN 8.9.
|
||||
|
||||
Your CUDA installation must also be new enough to support your GPU. If you have
|
||||
an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
|
||||
you must use CUDA 11.8 or newer.
|
||||
|
||||
|
||||
To install, run
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
|
||||
# Installs the wheel compatible with CUDA 12 and cuDNN 8.8 or newer.
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Installs the wheel compatible with Cuda 11.4+ and cudnn 8.2+ (deprecated).
|
||||
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
**These `pip` installations do not work with Windows, and may fail silently; see
|
||||
[above](#installation).**
|
||||
|
||||
The jaxlib version must correspond to the version of the existing CUDA
|
||||
installation you want to use. You can specify a particular CUDA and CuDNN
|
||||
version for jaxlib explicitly:
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
|
||||
# Installs the wheel compatible with Cuda >= 11.8 and cudnn >= 8.6
|
||||
pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
|
||||
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
You can find your CUDA version with the command:
|
||||
|
||||
```bash
|
||||
|
Loading…
x
Reference in New Issue
Block a user