diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index cc61d28fa..1722821b2 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -46,6 +46,7 @@ from jax._src.lib import cuda_versions from jax._src.lib import xla_client from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib logger = logging.getLogger(__name__) @@ -333,8 +334,11 @@ def make_gpu_client( ) if platform_name == "cuda": _check_cuda_versions() - devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count()) - _check_cuda_compute_capability(devices_to_check) + # TODO(micky774): remove this check when minimum jaxlib is v0.4.26 + if jaxlib.version.__version_info__ >= (0, 4, 26): + devices_to_check = (allowed_devices if allowed_devices else + range(cuda_versions.cuda_device_count())) + _check_cuda_compute_capability(devices_to_check) return xla_client.make_gpu_client( distributed_client=distributed.global_state.client,