Remove support for CUDA 11.

Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
This commit is contained in:
jax authors 2024-03-25 11:44:40 -07:00
parent 19e6156cce
commit 0be07e6aec
14 changed files with 193 additions and 160 deletions

View File

@ -228,30 +228,6 @@ build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda11.8_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.8_nvcc_base --config=cuda_clang
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_NVCC_CLANG="1"
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.8_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8"
build:rbe_linux_cuda11.8_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda11.8_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.8_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.8_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.8_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
build:rbe_linux_cuda11.8_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
build:rbe_linux_cuda11.8_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda"
build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_nccl"
build:rbe_linux_cuda11.8_nvcc_py3.9 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.9"
build:rbe_linux_cuda11.8_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_cuda11.8_nvcc_py3.10 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.10"
build:rbe_linux_cuda11.8_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
build:rbe_linux_cuda11.8_nvcc_py3.11 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.11"
build:rbe_linux_cuda11.8_nvcc_py3.11 --python_path="/usr/local/bin/python3.11"
build:rbe_linux_cuda11.8_nvcc_py3.12 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.12"
build:rbe_linux_cuda11.8_nvcc_py3.12 --python_path="/usr/local/bin/python3.12"
build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1"

View File

@ -32,6 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
## jaxlib 0.4.26
* Changes
* JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been
dropped.
## jax 0.4.25 (Feb 26, 2024)
* New Features

View File

@ -61,7 +61,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.
You must first install the NVIDIA driver. We
recommend installing the newest driver available from NVIDIA, but the driver
must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
must be version >= 525.60.13 for CUDA 12 on Linux.
If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
@ -82,10 +82,6 @@ pip install --upgrade pip
# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
If JAX detects the wrong version of the CUDA libraries, there are several things
@ -113,14 +109,19 @@ able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.
JAX currently ships two CUDA wheel variants:
* CUDA 12.3, cuDNN 8.9, NCCL 2.16
* CUDA 11.8, cuDNN 8.6, NCCL 2.16
JAX currently ships one CUDA wheel variant:
| Built with | Compatible with |
|------------|-----------------|
| CUDA 12.3 | CUDA 12.1+ |
| cuDNN 8.9 | cuDNN 8.9+ |
| NCCL 2.19 | NCCL 2.18+ |
You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL
installations match, and the minor versions are the same or newer.
JAX checks the versions of your libraries, and will report an error if they are
not sufficiently new.
Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable
the check, but using older versions of CUDA may lead to errors, or incorrect
results.
NCCL is an optional dependency, required only if you are performing multi-GPU
computations.
@ -134,9 +135,6 @@ pip install --upgrade pip
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
**These `pip` installations do not work with Windows, and may fail silently; see
@ -188,11 +186,6 @@ pip install -U --pre libtpu-nightly -f https://storage.googleapis.com/jax-releas
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```
* Jaxlib GPU (Cuda 11):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```
## Google TPU
### pip installation: Google Cloud TPU

View File

@ -72,7 +72,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.
You must first install the NVIDIA driver. You're
recommended to install the newest driver available from NVIDIA, but the driver
version must be >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
version must be >= 525.60.13 for CUDA 12 on Linux.
If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
@ -99,10 +99,6 @@ pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# NVIDIA CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things
@ -131,15 +127,19 @@ able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.
JAX currently ships two NVIDIA CUDA wheel variants:
JAX currently ships one CUDA wheel variant:
- CUDA 12.2, cuDNN 8.9, NCCL 2.16
- CUDA 11.8, cuDNN 8.6, NCCL 2.16
| Built with | Compatible with |
|------------|-----------------|
| CUDA 12.3 | CUDA 12.1+ |
| cuDNN 8.9 | cuDNN 8.9+ |
| NCCL 2.19 | NCCL 2.18+ |
You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL
installations match, and the minor versions are the same or newer.
JAX checks the versions of your libraries, and will report an error if they are
not sufficiently new.
Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable
the check, but using older versions of CUDA may lead to errors, or incorrect
results.
NCCL is an optional dependency, required only if you are performing multi-GPU
computations.
@ -152,10 +152,6 @@ pip install --upgrade pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with NVIDIA CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
**These `pip` installations do not work with Windows, and may fail silently; refer to the table
@ -212,12 +208,6 @@ pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/lib
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```
- `jaxlib` NVIDIA GPU (CUDA 11):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```
(install-google-tpu)=
## Google Cloud TPU
@ -318,4 +308,4 @@ pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_re
For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example
```bash
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
```

View File

@ -31,6 +31,7 @@ import logging
import os
import pkgutil
import platform as py_platform
import traceback
import sys
import threading
from typing import Any, Callable, Union
@ -267,33 +268,101 @@ def _check_cuda_compute_capability(devices_to_check):
RuntimeWarning
)
def _check_cuda_versions():
assert cuda_versions is not None
def _version_check(name, get_version, get_build_version,
scale_for_comparison=1):
def _check_cuda_versions(raise_on_first_error: bool = False,
debug: bool = False):
assert cuda_versions is not None
results: list[dict[str, Any]] = []
def _make_msg(name: str,
runtime_version: int,
build_version: int,
min_supported: int,
debug_msg: bool = False):
if debug_msg:
return (f"Package: {name}\n"
f"Version JAX was built against: {build_version}\n"
f"Minimum supported: {min_supported}\n"
f"Installed version: {runtime_version}")
if min_supported:
req_str = (f"The local installation version must be no lower than "
f"{min_supported}.")
else:
req_str = ("The local installation must be the same version as "
"the version against which JAX was built.")
msg = (f"Outdated {name} installation found.\n"
f"Version JAX was built against: {build_version}\n"
f"Minimum supported: {min_supported}\n"
f"Installed version: {runtime_version}\n"
f"{req_str}")
return msg
def _version_check(name: str,
get_version,
get_build_version,
scale_for_comparison: int = 1,
min_supported_version: int = 0):
"""Checks the runtime CUDA component version against the JAX one.
Args:
name: Of the CUDA component.
get_version: A function to get the local runtime version of the component.
get_build_version: A function to get the build version of the component.
scale_for_comparison: For rounding down a version to ignore patch/minor.
min_supported_version: An absolute minimum version required. Must be
passed without rounding down.
Raises:
RuntimeError: If the component is not found, or is of unsupported version,
and if raising the error is not deferred till later.
"""
build_version = get_build_version()
try:
version = get_version()
except Exception as e:
raise RuntimeError(f"Unable to load {name}. Is it installed?") from e
if build_version // scale_for_comparison > version // scale_for_comparison:
raise RuntimeError(
f"Found {name} version {version}, but JAX was built against version "
f"{build_version}, which is newer. The copy of {name} that is "
"installed must be at least as new as the version against which JAX "
"was built."
)
err_msg = f"Unable to load {name}. Is it installed?"
if raise_on_first_error:
raise RuntimeError(err_msg) from e
err_msg += f"\n{traceback.format_exc()}"
results.append({"name": name, "installed": False, "msg": err_msg})
return
if not min_supported_version:
min_supported_version = build_version // scale_for_comparison
passed = min_supported_version <= version
if not passed or debug:
msg = _make_msg(name=name,
runtime_version=version,
build_version=build_version,
min_supported=min_supported_version,
debug_msg=passed)
if not passed and raise_on_first_error:
raise RuntimeError(msg)
else:
record = {"name": name,
"installed": True,
"msg": msg,
"passed": passed,
"build_version": build_version,
"version": version,
"minimum_supported": min_supported_version}
results.append(record)
_version_check("CUDA", cuda_versions.cuda_runtime_get_version,
cuda_versions.cuda_runtime_build_version)
cuda_versions.cuda_runtime_build_version,
scale_for_comparison=10,
min_supported_version=12010)
_version_check(
"cuDNN",
cuda_versions.cudnn_get_version,
cuda_versions.cudnn_build_version,
# NVIDIA promise both backwards and forwards compatibility for cuDNN patch
# versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
# versions:
# https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
scale_for_comparison=100,
min_supported_version=8900
)
_version_check("cuFFT", cuda_versions.cufft_get_version,
cuda_versions.cufft_build_version,
@ -302,20 +371,42 @@ def _check_cuda_versions():
_version_check("cuSOLVER", cuda_versions.cusolver_get_version,
cuda_versions.cusolver_build_version,
# Ignore patch versions.
scale_for_comparison=100)
scale_for_comparison=100,
min_supported_version=11400)
_version_check("cuPTI", cuda_versions.cupti_get_version,
cuda_versions.cupti_build_version)
cuda_versions.cupti_build_version,
min_supported_version=18)
# TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
if hasattr(cuda_versions, "cublas_get_version"):
_version_check("cuBLAS", cuda_versions.cublas_get_version,
cuda_versions.cublas_build_version,
# Ignore patch versions.
scale_for_comparison=100)
scale_for_comparison=100,
min_supported_version=120100)
if hasattr(cuda_versions, "cusparse_get_version"):
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
cuda_versions.cusparse_build_version,
# Ignore patch versions.
scale_for_comparison=100)
scale_for_comparison=100,
min_supported_version=12100)
errors = []
debug_results = []
for result in results:
message: str = result['msg']
if not result['installed'] or not result['passed']:
errors.append(message)
else:
debug_results.append(message)
join_str = f'\n{"-" * 50}\n'
if debug_results:
print(f'CUDA components status (debug):\n'
f'{join_str.join(debug_results)}')
if errors:
raise RuntimeError(f'Unable to use CUDA because of the '
f'following issues with CUDA components:\n'
f'{join_str.join(errors)}')
def make_gpu_client(
@ -335,6 +426,10 @@ def make_gpu_client(
if platform_name == "cuda":
if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"):
_check_cuda_versions()
else:
print('Skipped CUDA versions constraints check due to the '
'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.')
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
if jaxlib.version.__version_info__ >= (0, 4, 26):
devices_to_check = (allowed_devices if allowed_devices else

View File

@ -25,7 +25,7 @@ import jax._src.xla_bridge as xb
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
# preinstalled jax cuda plugin packages.
for pkg_name in ['jax_cuda12_plugin', 'jax_cuda11_plugin', 'jaxlib']:
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'

View File

@ -24,7 +24,7 @@ from .gpu_common_utils import GpuLibNotLinkedError
from jaxlib import xla_client
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_linalg = importlib.import_module(
f"{cuda_module_name}._linalg", package="jaxlib"

View File

@ -27,7 +27,7 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call
from .gpu_common_utils import GpuLibNotLinkedError
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_prng = importlib.import_module(
f"{cuda_module_name}._prng", package="jaxlib"

View File

@ -22,7 +22,7 @@ import numpy as np
from jaxlib import xla_client
from .gpu_common_utils import GpuLibNotLinkedError
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib")
except ImportError:

View File

@ -33,7 +33,7 @@ from .hlo_helpers import (
try:
from .cuda import _blas as _cublas # pytype: disable=import-error
except ImportError:
for cuda_module_name in ["jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in ["jax_cuda12_plugin"]:
try:
_cublas = importlib.import_module(f"{cuda_module_name}._blas")
except ImportError:
@ -45,7 +45,7 @@ if _cublas:
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cusolver = importlib.import_module(
f"{cuda_module_name}._solver", package="jaxlib"

View File

@ -27,7 +27,7 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call, mk_result_types_and_shapes
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cusparse = importlib.import_module(
f"{cuda_module_name}._sparse", package="jaxlib"

View File

@ -15,7 +15,7 @@ import importlib
from jaxlib import xla_client
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_triton = importlib.import_module(
f"{cuda_module_name}._triton", package="jaxlib"

View File

@ -67,25 +67,17 @@ setup(
'ml_dtypes>=0.2.0',
],
extras_require={
'cuda11_pip': [
"nvidia-cublas-cu11>=11.11",
"nvidia-cuda-cupti-cu11>=11.8",
"nvidia-cuda-nvcc-cu11>=11.8",
"nvidia-cuda-runtime-cu11>=11.8",
"nvidia-cudnn-cu11>=8.8",
"nvidia-cufft-cu11>=10.9",
"nvidia-cusolver-cu11>=11.4",
"nvidia-cusparse-cu11>=11.7",
],
'cuda12_pip': [
"nvidia-cublas-cu12",
"nvidia-cuda-cupti-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-cuda-runtime-cu12",
"nvidia-cudnn-cu12>=8.9",
"nvidia-cufft-cu12",
"nvidia-cusolver-cu12",
"nvidia-cusparse-cu12",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12>=8.9.2.26",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
"nvidia-nvjitlink-cu12>=12.1.105",
],
},
url='https://github.com/google/jax',

View File

@ -25,9 +25,8 @@ project_name = 'jax'
_current_jaxlib_version = '0.4.25'
# The following should be updated with each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.25'
_available_cuda11_cudnn_versions = ['86']
_default_cuda11_cudnn_version = '86'
_default_cuda12_cudnn_version = '89'
_available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version]
_libtpu_version = '0.1.dev20240224'
def load_version_module(pkg_path):
@ -110,78 +109,62 @@ setup(
# CUDA installations require adding the JAX CUDA releases URL, e.g.,
# Cuda installation defaulting to a CUDA and Cudnn version defined above.
# $ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}"],
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}"],
'cuda11_pip': [
f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}",
"nvidia-cublas-cu11>=11.11",
"nvidia-cuda-cupti-cu11>=11.8",
"nvidia-cuda-nvcc-cu11>=11.8",
"nvidia-cuda-runtime-cu11>=11.8",
"nvidia-cudnn-cu11>=8.8",
"nvidia-cufft-cu11>=10.9",
"nvidia-cusolver-cu11>=11.4",
"nvidia-cusparse-cu11>=11.7",
"nvidia-nccl-cu11>=2.18.3",
],
'cuda12_pip': [
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
"nvidia-cublas-cu12>=12.3.4.1",
"nvidia-cuda-cupti-cu12>=12.3.101",
"nvidia-cuda-nvcc-cu12>=12.3.107",
"nvidia-cuda-runtime-cu12>=12.3.101",
"nvidia-cudnn-cu12>=8.9.7.29",
"nvidia-cufft-cu12>=11.0.12.1",
"nvidia-cusolver-cu12>=11.5.4.101",
"nvidia-cusparse-cu12>=12.2.0.103",
"nvidia-nccl-cu12>=2.19.3",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add an version constraint
# here.
"nvidia-nvjitlink-cu12>=12.3.101",
],
'cuda12_pip': [
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12>=8.9.2.26",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add a version constraint
# here.
"nvidia-nvjitlink-cu12>=12.1.105",
],
'cuda12': [
f"jaxlib=={_current_jaxlib_version}",
f"jax-cuda12-plugin=={_current_jaxlib_version}",
"nvidia-cublas-cu12>=12.3.4.1",
"nvidia-cuda-cupti-cu12>=12.3.101",
"nvidia-cuda-nvcc-cu12>=12.3.107",
"nvidia-cuda-runtime-cu12>=12.3.101",
"nvidia-cudnn-cu12>=8.9.7.29",
"nvidia-cufft-cu12>=11.0.12.1",
"nvidia-cusolver-cu12>=11.5.4.101",
"nvidia-cusparse-cu12>=12.2.0.103",
"nvidia-nccl-cu12>=2.19.3",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12>=8.9.2.26",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add an version constraint
# Until NVIDIA add version constraints, add a version constraint
# here.
"nvidia-nvjitlink-cu12>=12.3.101",
"nvidia-nvjitlink-cu12>=12.1.105",
],
# Target that does not depend on the CUDA pip wheels, for those who want
# to use a preinstalled CUDA.
'cuda11_local': [
f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}",
],
'cuda12_local': [
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
],
# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# $ pip install jax[cuda11_cudnn86] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**{f'cuda11_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{cudnn_version}"
for cudnn_version in _available_cuda11_cudnn_versions}
# $ pip install jax[cuda12_cudnn89] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**{f'cuda12_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{cudnn_version}"
for cudnn_version in _available_cuda12_cudnn_versions}
},
url='https://github.com/google/jax',
license='Apache-2.0',