xla_bridge: add logic to avoid version skew

This commit is contained in:
Jake VanderPlas 2023-11-22 12:17:21 -08:00
parent 961ba3cd42
commit 530cf30bfc

View File

@ -253,14 +253,17 @@ def _check_cuda_versions():
scale_for_comparison=100)
_version_check("cuPTI", cuda_versions.cupti_get_version,
cuda_versions.cupti_build_version)
_version_check("cuBLAS", cuda_versions.cublas_get_version,
cuda_versions.cublas_build_version,
# Ignore patch versions.
scale_for_comparison=100)
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
cuda_versions.cusparse_build_version,
# Ignore patch versions.
scale_for_comparison=100)
# TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
if hasattr(cuda_versions, "cublas_get_version"):
_version_check("cuBLAS", cuda_versions.cublas_get_version,
cuda_versions.cublas_build_version,
# Ignore patch versions.
scale_for_comparison=100)
if hasattr(cuda_versions, "cusparse_get_version"):
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
cuda_versions.cusparse_build_version,
# Ignore patch versions.
scale_for_comparison=100)
def make_gpu_client(