mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #25374 from traversaro:patch-1
PiperOrigin-RevId: 704673954
This commit is contained in:
commit
8e7aaa792b
@ -253,18 +253,14 @@ simply run:
|
||||
conda install jax -c conda-forge
|
||||
```
|
||||
|
||||
To install it on a machine with an NVIDIA GPU, run:
|
||||
If you run this command on machine with an NVIDIA GPU, this should install a CUDA-enabled package of `jaxlib`.
|
||||
|
||||
To ensure that the jax version you are installing is indeed CUDA-enabled, run:
|
||||
|
||||
```bash
|
||||
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
|
||||
conda install "jaxlib=*=*cuda*" jax -c conda-forge
|
||||
```
|
||||
|
||||
Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
|
||||
JAX requires. You must therefore either install the `cuda-nvcc` package from
|
||||
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
|
||||
is in your path. The channel order above is important (`conda-forge` before
|
||||
`nvidia`).
|
||||
|
||||
If you would like to override which release of CUDA is used by JAX, or to
|
||||
install the CUDA build on a machine without GPUs, follow the instructions in the
|
||||
[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
|
||||
|
Loading…
x
Reference in New Issue
Block a user