mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Relax cuDNN version compatibility test to ignore patch versions.
PiperOrigin-RevId: 569020492
This commit is contained in:
parent
59360794c1
commit
951298df64
@ -174,13 +174,14 @@ def _check_cuda_versions():
|
||||
if cuda_versions is None:
|
||||
return
|
||||
|
||||
def _version_check(name, get_version, get_build_version):
|
||||
def _version_check(name, get_version, get_build_version,
|
||||
scale_for_comparison=1):
|
||||
build_version = get_build_version()
|
||||
try:
|
||||
version = get_version()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unable to load {name}.") from e
|
||||
if build_version > version:
|
||||
if build_version // scale_for_comparison > version // scale_for_comparison:
|
||||
raise RuntimeError(
|
||||
f"Found {name} version {version}, but JAX was built against version "
|
||||
f"{build_version}, which is newer. The copy of {name} that is "
|
||||
@ -190,8 +191,14 @@ def _check_cuda_versions():
|
||||
|
||||
_version_check("CUDA", cuda_versions.cuda_runtime_get_version,
|
||||
cuda_versions.cuda_runtime_build_version)
|
||||
_version_check("cuDNN", cuda_versions.cudnn_get_version,
|
||||
cuda_versions.cudnn_build_version)
|
||||
_version_check(
|
||||
"cuDNN",
|
||||
cuda_versions.cudnn_get_version,
|
||||
cuda_versions.cudnn_build_version,
|
||||
# NVIDIA promise both backwards and forwards compatibility for cuDNN patch
|
||||
# versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
|
||||
scale_for_comparison=100,
|
||||
)
|
||||
_version_check("cuFFT", cuda_versions.cufft_get_version,
|
||||
cuda_versions.cufft_build_version)
|
||||
_version_check("cuSOLVER", cuda_versions.cusolver_get_version,
|
||||
|
Loading…
x
Reference in New Issue
Block a user