Update conda GPU install instructions.

This commit is contained in:
Peter Hawkins 2023-02-23 08:37:10 -05:00
parent c4778c4804
commit eb1710f84a

View File

@ -526,14 +526,14 @@ conda install jax -c conda-forge
To install on a machine with an NVidia GPU, run
```bash
conda install jax cuda-nvcc -c conda-forge -c nvidia
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
JAX requires. You must therefore either install the `cuda-nvcc` package from
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
is in your path. The channel order above is important (`conda-forge` before
`nvidia`). We are working on simplifying this.
`nvidia`).
If you would like to override which release of CUDA is used by JAX, or to
install the CUDA build on a machine without GPUs, follow the instructions in the