mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18991 from skye:revert_cuda_install
PiperOrigin-RevId: 591097432
This commit is contained in:
commit
a7b60234d9
@ -394,8 +394,8 @@ Some standouts:
|
||||
|
||||
| Hardware | Instructions |
|
||||
|------------|-----------------------------------------------------------------------------------------------------------------|
|
||||
| CPU | `pip install -U "jax[cpu]"` |
|
||||
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
|
||||
| CPU | `pip install -U "jax[cpu]"` |
|
||||
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
|
||||
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
|
||||
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
|
||||
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
|
||||
|
@ -55,7 +55,7 @@ not being installed alongside `jax`, although `jax` may successfully install
|
||||
|
||||
## NVIDIA GPU
|
||||
|
||||
JAX supports NVIDIA GPUs that have SM version 5.0 (Maxwell) or newer.
|
||||
JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer.
|
||||
Note that Kepler-series GPUs are no longer supported by JAX since
|
||||
NVIDIA has dropped support for Kepler GPUs in its software.
|
||||
|
||||
@ -81,11 +81,11 @@ pip install --upgrade pip
|
||||
|
||||
# CUDA 12 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12]"
|
||||
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# CUDA 11 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda11]"
|
||||
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
If JAX detects the wrong version of the CUDA libraries, there are several things
|
||||
@ -162,6 +162,37 @@ Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are
|
||||
bleeding edge containers containing nightly releases of jax and some
|
||||
models/frameworks.
|
||||
|
||||
## Nightly installation
|
||||
|
||||
Nightly releases reflect the state of the main repository at the time they are
|
||||
built, and may not pass the full test suite.
|
||||
|
||||
* JAX:
|
||||
```bash
|
||||
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib CPU:
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib TPU:
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib GPU (Cuda 12):
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib GPU (Cuda 11):
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
|
||||
```
|
||||
|
||||
## Google TPU
|
||||
|
||||
### pip installation: Google Cloud TPU
|
||||
@ -234,38 +265,6 @@ See the `conda-forge`
|
||||
[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
|
||||
for more details.
|
||||
|
||||
|
||||
## Nightly installation
|
||||
|
||||
Nightly releases reflect the state of the main repository at the time they are
|
||||
built, and may not pass the full test suite.
|
||||
|
||||
* JAX:
|
||||
```bash
|
||||
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib CPU:
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib TPU:
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib GPU (Cuda 12):
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib GPU (Cuda 11):
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
|
||||
```
|
||||
|
||||
## Building JAX from source
|
||||
See [Building JAX from source](developer.md#building-from-source).
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user