Copybara import of the project:

--
ac2c522dfe1e90ce7999e29f0fdf9660782db73d by Meekail Zain <zainmeekail@gmail.com>:

[FIX] Added jaxlib version guard for CUDA compute capability check

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/20237 from Micky774:add_version_guard ac2c522dfe1e90ce7999e29f0fdf9660782db73d
PiperOrigin-RevId: 616925918
This commit is contained in:
Meekail Zain 2024-03-18 13:15:50 -07:00 committed by jax authors
parent ee2631e4da
commit bc363de8a5

View File

@ -46,6 +46,7 @@ from jax._src.lib import cuda_versions
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib import jaxlib
logger = logging.getLogger(__name__)
@ -333,7 +334,10 @@ def make_gpu_client(
)
if platform_name == "cuda":
_check_cuda_versions()
devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count())
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
if jaxlib.version.__version_info__ >= (0, 4, 26):
devices_to_check = (allowed_devices if allowed_devices else
range(cuda_versions.cuda_device_count()))
_check_cuda_compute_capability(devices_to_check)
return xla_client.make_gpu_client(