mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add notes about the new CUDA version restrictions to the changelog and installation instructions.
This commit is contained in:
parent
87af945cbe
commit
b7dfde8d87
14
CHANGELOG.md
14
CHANGELOG.md
@ -16,18 +16,24 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
|
||||
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
|
||||
* Changes:
|
||||
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
|
||||
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
|
||||
1.11.
|
||||
* CUDA: JAX now verifies that the CUDA libraries it finds are at least as new
|
||||
as the CUDA libraries that JAX was built against. If older libraries are
|
||||
found, JAX raises an exception since that is preferable to mysterious
|
||||
failures and crashes.
|
||||
* Removed the "No GPU/TPU" found warning. Instead warn if, on Linux, an
|
||||
NVIDIA GPU or a Google TPU are found but not used and `--jax_platforms` was
|
||||
not specified.
|
||||
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
|
||||
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
|
||||
1.11.
|
||||
|
||||
# jaxlib 0.4.17
|
||||
|
||||
* Changes:
|
||||
* Python 3.12 wheels were added in this release.
|
||||
* The CUDA 12 wheels now require CUDA 12.2.
|
||||
* The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
|
||||
* This is likely the last release of jaxlib that supports CUDA 11. We intend
|
||||
to drop CUDA 11 support in the next release.
|
||||
|
||||
* Bug fixes:
|
||||
* Fixed log spam from ABSL when the JAX CPU backend was initialized.
|
||||
|
@ -58,8 +58,9 @@ not being installed alongside `jax`, although `jax` may successfully install
|
||||
### pip installation: GPU (CUDA, installed via pip, easier)
|
||||
|
||||
There are two ways to install JAX with NVIDIA GPU support: using CUDA and CUDNN
|
||||
installed from pip wheels, and using a self-installed CUDA/CUDNN. We recommend
|
||||
installing CUDA and CUDNN using the pip wheels, since it is much easier!
|
||||
installed from pip wheels, and using a self-installed CUDA/CUDNN. We
|
||||
strongly recommend installing CUDA and CUDNN using the pip wheels, since it is
|
||||
much easier!
|
||||
|
||||
JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer.
|
||||
Note that Kepler-series GPUs are no longer supported by JAX since
|
||||
@ -87,6 +88,13 @@ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-re
|
||||
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
If JAX detects the wrong version of the CUDA libraries, there are several things
|
||||
to check:
|
||||
* make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can
|
||||
override the CUDA libraries.
|
||||
* make sure that the CUDA libraries installed are those requested by JAX.
|
||||
Rerunning the installation command above should work.
|
||||
|
||||
### pip installation: GPU (CUDA, installed locally, harder)
|
||||
|
||||
If you prefer to use a preinstalled copy of CUDA, you must first
|
||||
@ -106,13 +114,13 @@ able to use the
|
||||
that NVIDIA provides for this purpose.
|
||||
|
||||
JAX currently ships two CUDA wheel variants:
|
||||
* CUDA 12.0 and CuDNN 8.9.
|
||||
* CUDA 12.2 and CuDNN 8.9.4.
|
||||
* CUDA 11.8 and CuDNN 8.6.
|
||||
|
||||
You may use a JAX wheel provided the major version of your CUDA and CuDNN
|
||||
installation matches, and the minor version is at least as new as the version
|
||||
JAX expects. For example, you would be able to use the CUDA 12.0 wheel with
|
||||
CUDA 12.1 and CuDNN 8.9.
|
||||
You may use a JAX wheel provided the major version of your CUDA
|
||||
installation matches, and the minor versions are the same or newer.
|
||||
JAX checks the versions of your libraries, and will report an error if they are
|
||||
not sufficiently new.
|
||||
|
||||
Your CUDA installation must also be new enough to support your GPU. If you have
|
||||
an Ada Lovelace (e.g., RTX 4080) or Hopper (e.g., H100) GPU,
|
||||
|
Loading…
x
Reference in New Issue
Block a user