Relax cuDNN version compatibility test to ignore patch versions.

PiperOrigin-RevId: 569020492
This commit is contained in:
Peter Hawkins 2023-09-27 18:33:14 -07:00 committed by jax authors
parent 59360794c1
commit 951298df64

View File

@ -174,13 +174,14 @@ def _check_cuda_versions():
if cuda_versions is None:
return
def _version_check(name, get_version, get_build_version):
def _version_check(name, get_version, get_build_version,
scale_for_comparison=1):
build_version = get_build_version()
try:
version = get_version()
except Exception as e:
raise RuntimeError(f"Unable to load {name}.") from e
if build_version > version:
if build_version // scale_for_comparison > version // scale_for_comparison:
raise RuntimeError(
f"Found {name} version {version}, but JAX was built against version "
f"{build_version}, which is newer. The copy of {name} that is "
@ -190,8 +191,14 @@ def _check_cuda_versions():
_version_check("CUDA", cuda_versions.cuda_runtime_get_version,
cuda_versions.cuda_runtime_build_version)
_version_check("cuDNN", cuda_versions.cudnn_get_version,
cuda_versions.cudnn_build_version)
_version_check(
"cuDNN",
cuda_versions.cudnn_get_version,
cuda_versions.cudnn_build_version,
# NVIDIA promise both backwards and forwards compatibility for cuDNN patch
# versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
scale_for_comparison=100,
)
_version_check("cuFFT", cuda_versions.cufft_get_version,
cuda_versions.cufft_build_version)
_version_check("cuSOLVER", cuda_versions.cusolver_get_version,