Merge pull request #18991 from skye:revert_cuda_install

PiperOrigin-RevId: 591097432
This commit is contained in:
jax authors 2023-12-14 17:28:34 -08:00
commit a7b60234d9
2 changed files with 36 additions and 37 deletions

View File

@ -394,8 +394,8 @@ Some standouts:
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

View File

@ -55,7 +55,7 @@ not being installed alongside `jax`, although `jax` may successfully install
## NVIDIA GPU
JAX supports NVIDIA GPUs that have SM version 5.0 (Maxwell) or newer.
JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer.
Note that Kepler-series GPUs are no longer supported by JAX since
NVIDIA has dropped support for Kepler GPUs in its software.
@ -81,11 +81,11 @@ pip install --upgrade pip
# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11]"
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
If JAX detects the wrong version of the CUDA libraries, there are several things
@ -162,6 +162,37 @@ Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are
bleeding edge containers containing nightly releases of jax and some
models/frameworks.
## Nightly installation
Nightly releases reflect the state of the main repository at the time they are
built, and may not pass the full test suite.
* JAX:
```bash
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
```
* Jaxlib CPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```
* Jaxlib TPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
* Jaxlib GPU (Cuda 12):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```
* Jaxlib GPU (Cuda 11):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```
## Google TPU
### pip installation: Google Cloud TPU
@ -234,38 +265,6 @@ See the `conda-forge`
[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
for more details.
## Nightly installation
Nightly releases reflect the state of the main repository at the time they are
built, and may not pass the full test suite.
* JAX:
```bash
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
```
* Jaxlib CPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```
* Jaxlib TPU:
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
* Jaxlib GPU (Cuda 12):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```
* Jaxlib GPU (Cuda 11):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```
## Building JAX from source
See [Building JAX from source](developer.md#building-from-source).