mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
xla_bridge: add logic to avoid version skew
This commit is contained in:
parent
961ba3cd42
commit
530cf30bfc
@ -253,14 +253,17 @@ def _check_cuda_versions():
|
||||
scale_for_comparison=100)
|
||||
_version_check("cuPTI", cuda_versions.cupti_get_version,
|
||||
cuda_versions.cupti_build_version)
|
||||
_version_check("cuBLAS", cuda_versions.cublas_get_version,
|
||||
cuda_versions.cublas_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
|
||||
cuda_versions.cusparse_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
# TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
|
||||
if hasattr(cuda_versions, "cublas_get_version"):
|
||||
_version_check("cuBLAS", cuda_versions.cublas_get_version,
|
||||
cuda_versions.cublas_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
if hasattr(cuda_versions, "cusparse_get_version"):
|
||||
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
|
||||
cuda_versions.cusparse_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
|
||||
|
||||
def make_gpu_client(
|
||||
|
Loading…
x
Reference in New Issue
Block a user