mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Update conda GPU install instructions.
This commit is contained in:
parent
c4778c4804
commit
eb1710f84a
@ -526,14 +526,14 @@ conda install jax -c conda-forge
|
||||
|
||||
To install on a machine with an NVidia GPU, run
|
||||
```bash
|
||||
conda install jax cuda-nvcc -c conda-forge -c nvidia
|
||||
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
|
||||
```
|
||||
|
||||
Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
|
||||
JAX requires. You must therefore either install the `cuda-nvcc` package from
|
||||
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
|
||||
is in your path. The channel order above is important (`conda-forge` before
|
||||
`nvidia`). We are working on simplifying this.
|
||||
`nvidia`).
|
||||
|
||||
If you would like to override which release of CUDA is used by JAX, or to
|
||||
install the CUDA build on a machine without GPUs, follow the instructions in the
|
||||
|
Loading…
x
Reference in New Issue
Block a user