mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1. Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5 PiperOrigin-RevId: 618911489
This commit is contained in:
parent
19e6156cce
commit
0be07e6aec
24
.bazelrc
24
.bazelrc
@ -228,30 +228,6 @@ build:rbe_linux_cuda_base --config=rbe_linux
|
||||
build:rbe_linux_cuda_base --config=cuda
|
||||
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
|
||||
|
||||
build:rbe_linux_cuda11.8_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda11.8_nvcc_base --config=cuda_clang
|
||||
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_NVCC_CLANG="1"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDA_VERSION=11
|
||||
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDNN_VERSION=8
|
||||
build:rbe_linux_cuda11.8_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda"
|
||||
build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_nccl"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.9 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.9"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.10 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.10"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.11 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.11"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.11 --python_path="/usr/local/bin/python3.11"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.12 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.12"
|
||||
build:rbe_linux_cuda11.8_nvcc_py3.12 --python_path="/usr/local/bin/python3.12"
|
||||
|
||||
build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang
|
||||
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1"
|
||||
|
@ -32,6 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jaxlib 0.4.26
|
||||
|
||||
* Changes
|
||||
* JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been
|
||||
dropped.
|
||||
|
||||
## jax 0.4.25 (Feb 26, 2024)
|
||||
|
||||
* New Features
|
||||
|
@ -61,7 +61,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.
|
||||
|
||||
You must first install the NVIDIA driver. We
|
||||
recommend installing the newest driver available from NVIDIA, but the driver
|
||||
must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
|
||||
must be version >= 525.60.13 for CUDA 12 on Linux.
|
||||
If you need to use a newer CUDA toolkit with an older driver, for example
|
||||
on a cluster where you cannot update the NVIDIA driver easily, you may be
|
||||
able to use the
|
||||
@ -82,10 +82,6 @@ pip install --upgrade pip
|
||||
# CUDA 12 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# CUDA 11 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
If JAX detects the wrong version of the CUDA libraries, there are several things
|
||||
@ -113,14 +109,19 @@ able to use the
|
||||
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
|
||||
that NVIDIA provides for this purpose.
|
||||
|
||||
JAX currently ships two CUDA wheel variants:
|
||||
* CUDA 12.3, cuDNN 8.9, NCCL 2.16
|
||||
* CUDA 11.8, cuDNN 8.6, NCCL 2.16
|
||||
JAX currently ships one CUDA wheel variant:
|
||||
|
||||
| Built with | Compatible with |
|
||||
|------------|-----------------|
|
||||
| CUDA 12.3 | CUDA 12.1+ |
|
||||
| cuDNN 8.9 | cuDNN 8.9+ |
|
||||
| NCCL 2.19 | NCCL 2.18+ |
|
||||
|
||||
You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL
|
||||
installations match, and the minor versions are the same or newer.
|
||||
JAX checks the versions of your libraries, and will report an error if they are
|
||||
not sufficiently new.
|
||||
Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable
|
||||
the check, but using older versions of CUDA may lead to errors, or incorrect
|
||||
results.
|
||||
|
||||
NCCL is an optional dependency, required only if you are performing multi-GPU
|
||||
computations.
|
||||
@ -134,9 +135,6 @@ pip install --upgrade pip
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
**These `pip` installations do not work with Windows, and may fail silently; see
|
||||
@ -188,11 +186,6 @@ pip install -U --pre libtpu-nightly -f https://storage.googleapis.com/jax-releas
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
|
||||
```
|
||||
|
||||
* Jaxlib GPU (Cuda 11):
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
|
||||
```
|
||||
|
||||
## Google TPU
|
||||
|
||||
### pip installation: Google Cloud TPU
|
||||
|
@ -72,7 +72,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.
|
||||
|
||||
You must first install the NVIDIA driver. You're
|
||||
recommended to install the newest driver available from NVIDIA, but the driver
|
||||
version must be >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
|
||||
version must be >= 525.60.13 for CUDA 12 on Linux.
|
||||
|
||||
If you need to use a newer CUDA toolkit with an older driver, for example
|
||||
on a cluster where you cannot update the NVIDIA driver easily, you may be
|
||||
@ -99,10 +99,6 @@ pip install --upgrade pip
|
||||
# NVIDIA CUDA 12 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# NVIDIA CUDA 11 installation
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things
|
||||
@ -131,15 +127,19 @@ able to use the
|
||||
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
|
||||
that NVIDIA provides for this purpose.
|
||||
|
||||
JAX currently ships two NVIDIA CUDA wheel variants:
|
||||
JAX currently ships one CUDA wheel variant:
|
||||
|
||||
- CUDA 12.2, cuDNN 8.9, NCCL 2.16
|
||||
- CUDA 11.8, cuDNN 8.6, NCCL 2.16
|
||||
| Built with | Compatible with |
|
||||
|------------|-----------------|
|
||||
| CUDA 12.3 | CUDA 12.1+ |
|
||||
| cuDNN 8.9 | cuDNN 8.9+ |
|
||||
| NCCL 2.19 | NCCL 2.18+ |
|
||||
|
||||
You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL
|
||||
installations match, and the minor versions are the same or newer.
|
||||
JAX checks the versions of your libraries, and will report an error if they are
|
||||
not sufficiently new.
|
||||
Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable
|
||||
the check, but using older versions of CUDA may lead to errors, or incorrect
|
||||
results.
|
||||
|
||||
NCCL is an optional dependency, required only if you are performing multi-GPU
|
||||
computations.
|
||||
@ -152,10 +152,6 @@ pip install --upgrade pip
|
||||
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Installs the wheel compatible with NVIDIA CUDA 11 and cuDNN 8.6 or newer.
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
**These `pip` installations do not work with Windows, and may fail silently; refer to the table
|
||||
@ -212,12 +208,6 @@ pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/lib
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
|
||||
```
|
||||
|
||||
- `jaxlib` NVIDIA GPU (CUDA 11):
|
||||
|
||||
```bash
|
||||
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
|
||||
```
|
||||
|
||||
(install-google-tpu)=
|
||||
## Google Cloud TPU
|
||||
|
||||
@ -318,4 +308,4 @@ pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_re
|
||||
For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example
|
||||
```bash
|
||||
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
```
|
||||
|
@ -31,6 +31,7 @@ import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import platform as py_platform
|
||||
import traceback
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Callable, Union
|
||||
@ -267,33 +268,101 @@ def _check_cuda_compute_capability(devices_to_check):
|
||||
RuntimeWarning
|
||||
)
|
||||
|
||||
def _check_cuda_versions():
|
||||
assert cuda_versions is not None
|
||||
|
||||
def _version_check(name, get_version, get_build_version,
|
||||
scale_for_comparison=1):
|
||||
def _check_cuda_versions(raise_on_first_error: bool = False,
|
||||
debug: bool = False):
|
||||
assert cuda_versions is not None
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
def _make_msg(name: str,
|
||||
runtime_version: int,
|
||||
build_version: int,
|
||||
min_supported: int,
|
||||
debug_msg: bool = False):
|
||||
if debug_msg:
|
||||
return (f"Package: {name}\n"
|
||||
f"Version JAX was built against: {build_version}\n"
|
||||
f"Minimum supported: {min_supported}\n"
|
||||
f"Installed version: {runtime_version}")
|
||||
if min_supported:
|
||||
req_str = (f"The local installation version must be no lower than "
|
||||
f"{min_supported}.")
|
||||
else:
|
||||
req_str = ("The local installation must be the same version as "
|
||||
"the version against which JAX was built.")
|
||||
msg = (f"Outdated {name} installation found.\n"
|
||||
f"Version JAX was built against: {build_version}\n"
|
||||
f"Minimum supported: {min_supported}\n"
|
||||
f"Installed version: {runtime_version}\n"
|
||||
f"{req_str}")
|
||||
return msg
|
||||
|
||||
def _version_check(name: str,
|
||||
get_version,
|
||||
get_build_version,
|
||||
scale_for_comparison: int = 1,
|
||||
min_supported_version: int = 0):
|
||||
"""Checks the runtime CUDA component version against the JAX one.
|
||||
|
||||
Args:
|
||||
name: Of the CUDA component.
|
||||
get_version: A function to get the local runtime version of the component.
|
||||
get_build_version: A function to get the build version of the component.
|
||||
scale_for_comparison: For rounding down a version to ignore patch/minor.
|
||||
min_supported_version: An absolute minimum version required. Must be
|
||||
passed without rounding down.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the component is not found, or is of unsupported version,
|
||||
and if raising the error is not deferred till later.
|
||||
"""
|
||||
|
||||
build_version = get_build_version()
|
||||
try:
|
||||
version = get_version()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unable to load {name}. Is it installed?") from e
|
||||
if build_version // scale_for_comparison > version // scale_for_comparison:
|
||||
raise RuntimeError(
|
||||
f"Found {name} version {version}, but JAX was built against version "
|
||||
f"{build_version}, which is newer. The copy of {name} that is "
|
||||
"installed must be at least as new as the version against which JAX "
|
||||
"was built."
|
||||
)
|
||||
err_msg = f"Unable to load {name}. Is it installed?"
|
||||
if raise_on_first_error:
|
||||
raise RuntimeError(err_msg) from e
|
||||
err_msg += f"\n{traceback.format_exc()}"
|
||||
results.append({"name": name, "installed": False, "msg": err_msg})
|
||||
return
|
||||
|
||||
if not min_supported_version:
|
||||
min_supported_version = build_version // scale_for_comparison
|
||||
passed = min_supported_version <= version
|
||||
|
||||
if not passed or debug:
|
||||
msg = _make_msg(name=name,
|
||||
runtime_version=version,
|
||||
build_version=build_version,
|
||||
min_supported=min_supported_version,
|
||||
debug_msg=passed)
|
||||
if not passed and raise_on_first_error:
|
||||
raise RuntimeError(msg)
|
||||
else:
|
||||
record = {"name": name,
|
||||
"installed": True,
|
||||
"msg": msg,
|
||||
"passed": passed,
|
||||
"build_version": build_version,
|
||||
"version": version,
|
||||
"minimum_supported": min_supported_version}
|
||||
results.append(record)
|
||||
|
||||
_version_check("CUDA", cuda_versions.cuda_runtime_get_version,
|
||||
cuda_versions.cuda_runtime_build_version)
|
||||
cuda_versions.cuda_runtime_build_version,
|
||||
scale_for_comparison=10,
|
||||
min_supported_version=12010)
|
||||
_version_check(
|
||||
"cuDNN",
|
||||
cuda_versions.cudnn_get_version,
|
||||
cuda_versions.cudnn_build_version,
|
||||
# NVIDIA promise both backwards and forwards compatibility for cuDNN patch
|
||||
# versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
|
||||
# versions:
|
||||
# https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=8900
|
||||
)
|
||||
_version_check("cuFFT", cuda_versions.cufft_get_version,
|
||||
cuda_versions.cufft_build_version,
|
||||
@ -302,20 +371,42 @@ def _check_cuda_versions():
|
||||
_version_check("cuSOLVER", cuda_versions.cusolver_get_version,
|
||||
cuda_versions.cusolver_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=11400)
|
||||
_version_check("cuPTI", cuda_versions.cupti_get_version,
|
||||
cuda_versions.cupti_build_version)
|
||||
cuda_versions.cupti_build_version,
|
||||
min_supported_version=18)
|
||||
# TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
|
||||
if hasattr(cuda_versions, "cublas_get_version"):
|
||||
_version_check("cuBLAS", cuda_versions.cublas_get_version,
|
||||
cuda_versions.cublas_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=120100)
|
||||
if hasattr(cuda_versions, "cusparse_get_version"):
|
||||
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
|
||||
cuda_versions.cusparse_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=12100)
|
||||
|
||||
errors = []
|
||||
debug_results = []
|
||||
for result in results:
|
||||
message: str = result['msg']
|
||||
if not result['installed'] or not result['passed']:
|
||||
errors.append(message)
|
||||
else:
|
||||
debug_results.append(message)
|
||||
|
||||
join_str = f'\n{"-" * 50}\n'
|
||||
if debug_results:
|
||||
print(f'CUDA components status (debug):\n'
|
||||
f'{join_str.join(debug_results)}')
|
||||
if errors:
|
||||
raise RuntimeError(f'Unable to use CUDA because of the '
|
||||
f'following issues with CUDA components:\n'
|
||||
f'{join_str.join(errors)}')
|
||||
|
||||
|
||||
def make_gpu_client(
|
||||
@ -335,6 +426,10 @@ def make_gpu_client(
|
||||
if platform_name == "cuda":
|
||||
if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"):
|
||||
_check_cuda_versions()
|
||||
else:
|
||||
print('Skipped CUDA versions constraints check due to the '
|
||||
'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.')
|
||||
|
||||
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
|
||||
if jaxlib.version.__version_info__ >= (0, 4, 26):
|
||||
devices_to_check = (allowed_devices if allowed_devices else
|
||||
|
@ -25,7 +25,7 @@ import jax._src.xla_bridge as xb
|
||||
|
||||
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
|
||||
# preinstalled jax cuda plugin packages.
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jax_cuda11_plugin', 'jaxlib']:
|
||||
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
|
||||
try:
|
||||
cuda_plugin_extension = importlib.import_module(
|
||||
f'{pkg_name}.cuda_plugin_extension'
|
||||
|
@ -24,7 +24,7 @@ from .gpu_common_utils import GpuLibNotLinkedError
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuda_linalg = importlib.import_module(
|
||||
f"{cuda_module_name}._linalg", package="jaxlib"
|
||||
|
@ -27,7 +27,7 @@ from jaxlib import xla_client
|
||||
from .hlo_helpers import custom_call
|
||||
from .gpu_common_utils import GpuLibNotLinkedError
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuda_prng = importlib.import_module(
|
||||
f"{cuda_module_name}._prng", package="jaxlib"
|
||||
|
@ -22,7 +22,7 @@ import numpy as np
|
||||
from jaxlib import xla_client
|
||||
from .gpu_common_utils import GpuLibNotLinkedError
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib")
|
||||
except ImportError:
|
||||
|
@ -33,7 +33,7 @@ from .hlo_helpers import (
|
||||
try:
|
||||
from .cuda import _blas as _cublas # pytype: disable=import-error
|
||||
except ImportError:
|
||||
for cuda_module_name in ["jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
for cuda_module_name in ["jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cublas = importlib.import_module(f"{cuda_module_name}._blas")
|
||||
except ImportError:
|
||||
@ -45,7 +45,7 @@ if _cublas:
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cusolver = importlib.import_module(
|
||||
f"{cuda_module_name}._solver", package="jaxlib"
|
||||
|
@ -27,7 +27,7 @@ from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import custom_call, mk_result_types_and_shapes
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cusparse = importlib.import_module(
|
||||
f"{cuda_module_name}._sparse", package="jaxlib"
|
||||
|
@ -15,7 +15,7 @@ import importlib
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
try:
|
||||
_cuda_triton = importlib.import_module(
|
||||
f"{cuda_module_name}._triton", package="jaxlib"
|
||||
|
@ -67,25 +67,17 @@ setup(
|
||||
'ml_dtypes>=0.2.0',
|
||||
],
|
||||
extras_require={
|
||||
'cuda11_pip': [
|
||||
"nvidia-cublas-cu11>=11.11",
|
||||
"nvidia-cuda-cupti-cu11>=11.8",
|
||||
"nvidia-cuda-nvcc-cu11>=11.8",
|
||||
"nvidia-cuda-runtime-cu11>=11.8",
|
||||
"nvidia-cudnn-cu11>=8.8",
|
||||
"nvidia-cufft-cu11>=10.9",
|
||||
"nvidia-cusolver-cu11>=11.4",
|
||||
"nvidia-cusparse-cu11>=11.7",
|
||||
],
|
||||
'cuda12_pip': [
|
||||
"nvidia-cublas-cu12",
|
||||
"nvidia-cuda-cupti-cu12",
|
||||
"nvidia-cuda-nvcc-cu12",
|
||||
"nvidia-cuda-runtime-cu12",
|
||||
"nvidia-cudnn-cu12>=8.9",
|
||||
"nvidia-cufft-cu12",
|
||||
"nvidia-cusolver-cu12",
|
||||
"nvidia-cusparse-cu12",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.1.105",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12>=8.9.2.26",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
],
|
||||
},
|
||||
url='https://github.com/google/jax',
|
||||
|
89
setup.py
89
setup.py
@ -25,9 +25,8 @@ project_name = 'jax'
|
||||
_current_jaxlib_version = '0.4.25'
|
||||
# The following should be updated with each new jaxlib release.
|
||||
_latest_jaxlib_version_on_pypi = '0.4.25'
|
||||
_available_cuda11_cudnn_versions = ['86']
|
||||
_default_cuda11_cudnn_version = '86'
|
||||
_default_cuda12_cudnn_version = '89'
|
||||
_available_cuda12_cudnn_versions = [_default_cuda12_cudnn_version]
|
||||
_libtpu_version = '0.1.dev20240224'
|
||||
|
||||
def load_version_module(pkg_path):
|
||||
@ -110,78 +109,62 @@ setup(
|
||||
# CUDA installations require adding the JAX CUDA releases URL, e.g.,
|
||||
# Cuda installation defaulting to a CUDA and Cudnn version defined above.
|
||||
# $ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}"],
|
||||
'cuda': [f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}"],
|
||||
|
||||
'cuda11_pip': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}",
|
||||
"nvidia-cublas-cu11>=11.11",
|
||||
"nvidia-cuda-cupti-cu11>=11.8",
|
||||
"nvidia-cuda-nvcc-cu11>=11.8",
|
||||
"nvidia-cuda-runtime-cu11>=11.8",
|
||||
"nvidia-cudnn-cu11>=8.8",
|
||||
"nvidia-cufft-cu11>=10.9",
|
||||
"nvidia-cusolver-cu11>=11.4",
|
||||
"nvidia-cusparse-cu11>=11.7",
|
||||
"nvidia-nccl-cu11>=2.18.3",
|
||||
],
|
||||
|
||||
'cuda12_pip': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
|
||||
"nvidia-cublas-cu12>=12.3.4.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.3.101",
|
||||
"nvidia-cuda-nvcc-cu12>=12.3.107",
|
||||
"nvidia-cuda-runtime-cu12>=12.3.101",
|
||||
"nvidia-cudnn-cu12>=8.9.7.29",
|
||||
"nvidia-cufft-cu12>=11.0.12.1",
|
||||
"nvidia-cusolver-cu12>=11.5.4.101",
|
||||
"nvidia-cusparse-cu12>=12.2.0.103",
|
||||
"nvidia-nccl-cu12>=2.19.3",
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
# package doesn't get upgraded even though not doing that can cause
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add an version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.3.101",
|
||||
],
|
||||
'cuda12_pip': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.1.105",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12>=8.9.2.26",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
# package doesn't get upgraded even though not doing that can cause
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add a version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
],
|
||||
|
||||
'cuda12': [
|
||||
f"jaxlib=={_current_jaxlib_version}",
|
||||
f"jax-cuda12-plugin=={_current_jaxlib_version}",
|
||||
"nvidia-cublas-cu12>=12.3.4.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.3.101",
|
||||
"nvidia-cuda-nvcc-cu12>=12.3.107",
|
||||
"nvidia-cuda-runtime-cu12>=12.3.101",
|
||||
"nvidia-cudnn-cu12>=8.9.7.29",
|
||||
"nvidia-cufft-cu12>=11.0.12.1",
|
||||
"nvidia-cusolver-cu12>=11.5.4.101",
|
||||
"nvidia-cusparse-cu12>=12.2.0.103",
|
||||
"nvidia-nccl-cu12>=2.19.3",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.1.105",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12>=8.9.2.26",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
# nvjitlink is not a direct dependency of JAX, but it is a transitive
|
||||
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
|
||||
# do not have a version constraint on their dependencies, so the
|
||||
# package doesn't get upgraded even though not doing that can cause
|
||||
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
|
||||
# Until NVIDIA add version constraints, add an version constraint
|
||||
# Until NVIDIA add version constraints, add a version constraint
|
||||
# here.
|
||||
"nvidia-nvjitlink-cu12>=12.3.101",
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
],
|
||||
|
||||
# Target that does not depend on the CUDA pip wheels, for those who want
|
||||
# to use a preinstalled CUDA.
|
||||
'cuda11_local': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{_default_cuda11_cudnn_version}",
|
||||
],
|
||||
'cuda12_local': [
|
||||
f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{_default_cuda12_cudnn_version}",
|
||||
],
|
||||
|
||||
# CUDA installations require adding jax releases URL; e.g.
|
||||
# $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
# $ pip install jax[cuda11_cudnn86] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
**{f'cuda11_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda11.cudnn{cudnn_version}"
|
||||
for cudnn_version in _available_cuda11_cudnn_versions}
|
||||
# $ pip install jax[cuda12_cudnn89] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
**{f'cuda12_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda12.cudnn{cudnn_version}"
|
||||
for cudnn_version in _available_cuda12_cudnn_versions}
|
||||
},
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
|
Loading…
x
Reference in New Issue
Block a user