mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax. * Remove mentions of monolithic CUDA installations from the JAX documentation. * Drop the cuda12_pip extra and the cudnn version specific extras. * Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place. * Make a few small doc improvements.
This commit is contained in:
parent
a9edaeb38e
commit
b13733c13f
@ -13,9 +13,16 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
bumped to 0.4.0 but this has been rolled back in this release to give users
|
||||
of both TensorFlow and JAX more time to migrate to a newer TensorFlow
|
||||
release.
|
||||
* jax now depends on jaxlib directly. This change was enabled by the CUDA
|
||||
plugin switch: there are no longer multiple jaxlib variants. You can install
|
||||
a CPU-only jax with `pip install jax`, no extras required.
|
||||
|
||||
## jaxlib 0.4.30
|
||||
|
||||
* Support for monolithic CUDA jaxlibs has been dropped. You must use the
|
||||
plugin-based installation (`pip install jax[cuda12]` or
|
||||
`pip install jax[cuda12_local]`).
|
||||
|
||||
## jax 0.4.29 (June 10, 2024)
|
||||
|
||||
* Changes
|
||||
|
@ -396,8 +396,8 @@ Some standouts:
|
||||
|
||||
| Hardware | Instructions |
|
||||
|------------|-----------------------------------------------------------------------------------------------------------------|
|
||||
| CPU | `pip install -U "jax[cpu]"` |
|
||||
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12]"` |
|
||||
| CPU | `pip install -U jax` |
|
||||
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
|
||||
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
|
||||
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
|
||||
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
|
||||
|
@ -9,20 +9,16 @@ different builds for different operating systems and accelerators.
|
||||
|
||||
* **CPU-only (Linux/macOS/Windows)**
|
||||
```
|
||||
pip install -U "jax[cpu]"
|
||||
pip install -U jax
|
||||
```
|
||||
* **GPU (NVIDIA, CUDA 12, x86_64)**
|
||||
* **GPU (NVIDIA, CUDA 12)**
|
||||
```
|
||||
pip install -U "jax[cuda12]"
|
||||
```
|
||||
|
||||
* **GPU (NVIDIA, CUDA 12, x86_64) legacy**
|
||||
|
||||
You should prefer `jax[cuda12]`, which uses the common CPU jaxlib and adds GPU
|
||||
support as a plugin. The monolithic `jax[cuda12_pip]` option will be removed in
|
||||
a future JAX release.
|
||||
* **TPU (Google Cloud TPU VM) **
|
||||
```
|
||||
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
```
|
||||
|
||||
(install-supported-platforms)=
|
||||
@ -48,6 +44,7 @@ Currently, the JAX team releases `jaxlib` wheels for the following
|
||||
operating systems and architectures:
|
||||
|
||||
- Linux, x86_64
|
||||
- Linux, aarch64
|
||||
- macOS, Intel
|
||||
- macOS, Apple ARM-based
|
||||
- Windows, x86_64 (*experimental*)
|
||||
@ -57,7 +54,7 @@ development on a laptop, you can run:
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade "jax[cpu]"
|
||||
pip install --upgrade jax
|
||||
```
|
||||
|
||||
On Windows, you may also need to install the
|
||||
@ -97,8 +94,8 @@ There are two ways to install JAX with NVIDIA GPU support:
|
||||
The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels,
|
||||
since it is much easier!
|
||||
|
||||
This method is only supported on x86_64, because NVIDIA has not released aarch64
|
||||
CUDA pip packages.
|
||||
NVIDIA has released CUDA pip packages only for x86_64 and aarch64; on other
|
||||
platforms you must use a local installation of CUDA.
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
@ -106,11 +103,6 @@ pip install --upgrade pip
|
||||
# NVIDIA CUDA 12 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12]"
|
||||
|
||||
# Legacy way of NVIDIA CUDA 12 installation. You should prefer `jax[cuda12]`,
|
||||
# which uses the common CPU jaxlib and adds GPU support as a plugin. The
|
||||
# monolithic `jax[cuda12_pip]` option will be removed in a future JAX release.
|
||||
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things
|
||||
@ -127,7 +119,7 @@ If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first
|
||||
install NVIDIA [CUDA](https://developer.nvidia.com/cuda-downloads) and
|
||||
[cuDNN](https://developer.nvidia.com/CUDNN).
|
||||
|
||||
JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 only**. Other
|
||||
JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 and Linux aarch64 only**. Other
|
||||
combinations of operating system and architecture are possible, but require
|
||||
building from source (refer to {ref}`building-from-source` to learn more}.
|
||||
|
||||
@ -141,11 +133,11 @@ that NVIDIA provides for this purpose.
|
||||
|
||||
JAX currently ships one CUDA wheel variant:
|
||||
|
||||
| Built with | Compatible with |
|
||||
|------------|-------------------|
|
||||
| CUDA 12.3 | CUDA >=12.1 |
|
||||
| CUDNN 8.9 | CUDNN >=8.9, <9.0 |
|
||||
| NCCL 2.19 | NCCL >=2.18 |
|
||||
| Built with | Compatible with |
|
||||
|------------|--------------------|
|
||||
| CUDA 12.3 | CUDA >=12.1 |
|
||||
| CUDNN 9.0 | CUDNN >=9.0, <10.0 |
|
||||
| NCCL 2.19 | NCCL >=2.18 |
|
||||
|
||||
JAX checks the versions of your libraries, and will report an error if they are
|
||||
not sufficiently new.
|
||||
@ -163,7 +155,7 @@ pip install --upgrade pip
|
||||
|
||||
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
pip install --upgrade "jax[cuda12_local]"
|
||||
```
|
||||
|
||||
**These `pip` installations do not work with Windows, and may fail silently; refer to the table
|
||||
|
@ -27,7 +27,7 @@ This document provides a quick overview of essential JAX features, so you can ge
|
||||
|
||||
JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Python Package Index](https://pypi.org/project/jax/):
|
||||
```
|
||||
pip install "jax[cpu]"
|
||||
pip install jax
|
||||
```
|
||||
or, for NVIDIA GPU:
|
||||
```
|
||||
|
@ -51,6 +51,27 @@ setup(
|
||||
packages=[package_name],
|
||||
python_requires=">=3.9",
|
||||
install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"],
|
||||
extras_require={
|
||||
'with_cuda': [
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.1.105",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12>=9.0,<10.0",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
# package doesn't get upgraded even though not doing that can cause
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add a version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
],
|
||||
},
|
||||
url="https://github.com/google/jax",
|
||||
license="Apache-2.0",
|
||||
classifiers=[
|
||||
|
@ -66,21 +66,6 @@ setup(
|
||||
'numpy>=1.22',
|
||||
'ml_dtypes>=0.2.0',
|
||||
],
|
||||
extras_require={
|
||||
'cuda12_pip': [
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.1.105",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
# https://docs.nvidia.com/deeplearning/cudnn/developer/misc.html#cudnn-api-compatibility
|
||||
"nvidia-cudnn-cu12>=9.0,<10.0",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
],
|
||||
},
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
classifiers=[
|
||||
|
70
setup.py
70
setup.py
@ -22,8 +22,6 @@ project_name = 'jax'
|
||||
_current_jaxlib_version = '0.4.29'
|
||||
# The following should be updated with each new jaxlib release.
|
||||
_latest_jaxlib_version_on_pypi = '0.4.29'
|
||||
_default_cuda12_cudnn_version = '91'
|
||||
_available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version]
|
||||
_libtpu_version = '0.1.dev20240609'
|
||||
|
||||
def load_version_module(pkg_path):
|
||||
@ -35,6 +33,7 @@ def load_version_module(pkg_path):
|
||||
|
||||
_version_module = load_version_module(project_name)
|
||||
__version__ = _version_module._get_version_for_build()
|
||||
_jax_version = _version_module._version # JAX version, with no .dev suffix.
|
||||
_cmdclass = _version_module._get_cmdclass(project_name)
|
||||
_minimum_jaxlib_version = _version_module._minimum_jaxlib_version
|
||||
|
||||
@ -54,6 +53,7 @@ setup(
|
||||
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
|
||||
python_requires='>=3.9',
|
||||
install_requires=[
|
||||
f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
|
||||
'ml_dtypes>=0.2.0',
|
||||
'numpy>=1.22',
|
||||
"numpy>=1.23.2; python_version>='3.11'",
|
||||
@ -70,9 +70,9 @@ setup(
|
||||
# Minimum jaxlib version; used in testing.
|
||||
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],
|
||||
|
||||
# CPU-only jaxlib can be installed via:
|
||||
# $ pip install jax[cpu]
|
||||
'cpu': [f'jaxlib=={_current_jaxlib_version}'],
|
||||
# A CPU-only jax doesn't require any extras, but we keep this extra
|
||||
# around for compatibility.
|
||||
'cpu': [],
|
||||
|
||||
# Used only for CI builds that install JAX from github HEAD.
|
||||
'ci': [f'jaxlib=={_latest_jaxlib_version_on_pypi}'],
|
||||
@ -80,71 +80,27 @@ setup(
|
||||
# Cloud TPU VM jaxlib can be installed via:
|
||||
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
'tpu': [
|
||||
f'jaxlib=={_current_jaxlib_version}',
|
||||
f'jaxlib>={_current_jaxlib_version},<={_jax_version}',
|
||||
f'libtpu-nightly=={_libtpu_version}',
|
||||
'requests', # necessary for jax.distributed.initialize
|
||||
],
|
||||
|
||||
# CUDA installations require adding the JAX CUDA releases URL, e.g.,
|
||||
# Cuda installation defaulting to a CUDA and Cudnn version defined above.
|
||||
# $ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}"],
|
||||
|
||||
|
||||
'cuda12_pip': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.1.105",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
# https://docs.nvidia.com/deeplearning/cudnn/developer/misc.html#cudnn-api-compatibility
|
||||
"nvidia-cudnn-cu12>=9.0,<10.0",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
# package doesn't get upgraded even though not doing that can cause
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add a version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
],
|
||||
'cuda': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
|
||||
],
|
||||
|
||||
'cuda12': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jax-cuda12-plugin=={_current_jaxlib_version}",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.1.105",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12>=9.0,<10.0",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
# package doesn't get upgraded even though not doing that can cause
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add a version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}",
|
||||
],
|
||||
|
||||
# Target that does not depend on the CUDA pip wheels, for those who want
|
||||
# to use a preinstalled CUDA.
|
||||
'cuda12_local': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jax-cuda12-plugin=={_current_jaxlib_version}",
|
||||
],
|
||||
|
||||
# CUDA installations require adding jax releases URL; e.g.
|
||||
# $ pip install jax[cuda12_cudnn89] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
**{f'cuda12_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{cudnn_version}"
|
||||
for cudnn_version in _available_cuda12_cudnn_versions}
|
||||
},
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
|
Loading…
x
Reference in New Issue
Block a user