Update JAX dependencies, extras, and documentation for plugins.

* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
This commit is contained in:
Peter Hawkins 2024-06-13 10:50:24 -04:00
parent a9edaeb38e
commit b13733c13f
7 changed files with 59 additions and 98 deletions

View File

@ -13,9 +13,16 @@ Remember to align the itemized text with the first line of an item within a list
bumped to 0.4.0 but this has been rolled back in this release to give users
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
release.
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
## jaxlib 0.4.30
* Support for monolithic CUDA jaxlibs has been dropped. You must use the
plugin-based installation (`pip install jax[cuda12]` or
`pip install jax[cuda12_local]`).
## jax 0.4.29 (June 10, 2024)
* Changes

View File

@ -396,8 +396,8 @@ Some standouts:
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12]"` |
| CPU | `pip install -U jax` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

View File

@ -9,20 +9,16 @@ different builds for different operating systems and accelerators.
* **CPU-only (Linux/macOS/Windows)**
```
pip install -U "jax[cpu]"
pip install -U jax
```
* **GPU (NVIDIA, CUDA 12, x86_64)**
* **GPU (NVIDIA, CUDA 12)**
```
pip install -U "jax[cuda12]"
```
* **GPU (NVIDIA, CUDA 12, x86_64) legacy**
You should prefer `jax[cuda12]`, which uses the common CPU jaxlib and adds GPU
support as a plugin. The monolithic `jax[cuda12_pip]` option will be removed in
a future JAX release.
* **TPU (Google Cloud TPU VM) **
```
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
(install-supported-platforms)=
@ -48,6 +44,7 @@ Currently, the JAX team releases `jaxlib` wheels for the following
operating systems and architectures:
- Linux, x86_64
- Linux, aarch64
- macOS, Intel
- macOS, Apple ARM-based
- Windows, x86_64 (*experimental*)
@ -57,7 +54,7 @@ development on a laptop, you can run:
```bash
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
pip install --upgrade jax
```
On Windows, you may also need to install the
@ -97,8 +94,8 @@ There are two ways to install JAX with NVIDIA GPU support:
The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels,
since it is much easier!
This method is only supported on x86_64, because NVIDIA has not released aarch64
CUDA pip packages.
NVIDIA has released CUDA pip packages only for x86_64 and aarch64; on other
platforms you must use a local installation of CUDA.
```bash
pip install --upgrade pip
@ -106,11 +103,6 @@ pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"
# Legacy way of NVIDIA CUDA 12 installation. You should prefer `jax[cuda12]`,
# which uses the common CPU jaxlib and adds GPU support as a plugin. The
# monolithic `jax[cuda12_pip]` option will be removed in a future JAX release.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things
@ -127,7 +119,7 @@ If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first
install NVIDIA [CUDA](https://developer.nvidia.com/cuda-downloads) and
[cuDNN](https://developer.nvidia.com/CUDNN).
JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 and Linux aarch64 only**. Other
combinations of operating system and architecture are possible, but require
building from source (refer to {ref}`building-from-source` to learn more}.
@ -141,11 +133,11 @@ that NVIDIA provides for this purpose.
JAX currently ships one CUDA wheel variant:
| Built with | Compatible with |
|------------|-------------------|
| CUDA 12.3 | CUDA >=12.1 |
| CUDNN 8.9 | CUDNN >=8.9, <9.0 |
| NCCL 2.19 | NCCL >=2.18 |
| Built with | Compatible with |
|------------|--------------------|
| CUDA 12.3 | CUDA >=12.1 |
| CUDNN 9.0 | CUDNN >=9.0, <10.0 |
| NCCL 2.19 | NCCL >=2.18 |
JAX checks the versions of your libraries, and will report an error if they are
not sufficiently new.
@ -163,7 +155,7 @@ pip install --upgrade pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda12_local]"
```
**These `pip` installations do not work with Windows, and may fail silently; refer to the table

View File

@ -27,7 +27,7 @@ This document provides a quick overview of essential JAX features, so you can ge
JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Python Package Index](https://pypi.org/project/jax/):
```
pip install "jax[cpu]"
pip install jax
```
or, for NVIDIA GPU:
```

View File

@ -51,6 +51,27 @@ setup(
packages=[package_name],
python_requires=">=3.9",
install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"],
extras_require={
'with_cuda': [
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12>=9.0,<10.0",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add a version constraint
# here.
"nvidia-nvjitlink-cu12>=12.1.105",
],
},
url="https://github.com/google/jax",
license="Apache-2.0",
classifiers=[

View File

@ -66,21 +66,6 @@ setup(
'numpy>=1.22',
'ml_dtypes>=0.2.0',
],
extras_require={
'cuda12_pip': [
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
# https://docs.nvidia.com/deeplearning/cudnn/developer/misc.html#cudnn-api-compatibility
"nvidia-cudnn-cu12>=9.0,<10.0",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
"nvidia-nvjitlink-cu12>=12.1.105",
],
},
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[

View File

@ -22,8 +22,6 @@ project_name = 'jax'
_current_jaxlib_version = '0.4.29'
# The following should be updated with each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.29'
_default_cuda12_cudnn_version = '91'
_available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version]
_libtpu_version = '0.1.dev20240609'
def load_version_module(pkg_path):
@ -35,6 +33,7 @@ def load_version_module(pkg_path):
_version_module = load_version_module(project_name)
__version__ = _version_module._get_version_for_build()
_jax_version = _version_module._version # JAX version, with no .dev suffix.
_cmdclass = _version_module._get_cmdclass(project_name)
_minimum_jaxlib_version = _version_module._minimum_jaxlib_version
@ -54,6 +53,7 @@ setup(
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.9',
install_requires=[
f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
'ml_dtypes>=0.2.0',
'numpy>=1.22',
"numpy>=1.23.2; python_version>='3.11'",
@ -70,9 +70,9 @@ setup(
# Minimum jaxlib version; used in testing.
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],
# CPU-only jaxlib can be installed via:
# $ pip install jax[cpu]
'cpu': [f'jaxlib=={_current_jaxlib_version}'],
# A CPU-only jax doesn't require any extras, but we keep this extra
# around for compatibility.
'cpu': [],
# Used only for CI builds that install JAX from github HEAD.
'ci': [f'jaxlib=={_latest_jaxlib_version_on_pypi}'],
@ -80,71 +80,27 @@ setup(
# Cloud TPU VM jaxlib can be installed via:
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
'tpu': [
f'jaxlib=={_current_jaxlib_version}',
f'jaxlib>={_current_jaxlib_version},<={_jax_version}',
f'libtpu-nightly=={_libtpu_version}',
'requests', # necessary for jax.distributed.initialize
],
# CUDA installations require adding the JAX CUDA releases URL, e.g.,
# Cuda installation defaulting to a CUDA and Cudnn version defined above.
# $ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}"],
'cuda12_pip': [
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
# https://docs.nvidia.com/deeplearning/cudnn/developer/misc.html#cudnn-api-compatibility
"nvidia-cudnn-cu12>=9.0,<10.0",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add a version constraint
# here.
"nvidia-nvjitlink-cu12>=12.1.105",
],
'cuda': [
f"jaxlib=={_current_jaxlib_version}",
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
],
'cuda12': [
f"jaxlib=={_current_jaxlib_version}",
f"jax-cuda12-plugin=={_current_jaxlib_version}",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12>=9.0,<10.0",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add a version constraint
# here.
"nvidia-nvjitlink-cu12>=12.1.105",
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
],
# Target that does not depend on the CUDA pip wheels, for those who want
# to use a preinstalled CUDA.
'cuda12_local': [
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
f"jaxlib=={_current_jaxlib_version}",
f"jax-cuda12-plugin=={_current_jaxlib_version}",
],
# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda12_cudnn89] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**{f'cuda12_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{cudnn_version}"
for cudnn_version in _available_cuda12_cudnn_versions}
},
url='https://github.com/google/jax',
license='Apache-2.0',