mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Copybara import of the project:
-- ac2c522dfe1e90ce7999e29f0fdf9660782db73d by Meekail Zain <zainmeekail@gmail.com>: [FIX] Added jaxlib version guard for CUDA compute capability check COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/20237 from Micky774:add_version_guard ac2c522dfe1e90ce7999e29f0fdf9660782db73d PiperOrigin-RevId: 616925918
This commit is contained in:
parent
ee2631e4da
commit
bc363de8a5
@ -46,6 +46,7 @@ from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import jaxlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -333,7 +334,10 @@ def make_gpu_client(
|
||||
)
|
||||
if platform_name == "cuda":
|
||||
_check_cuda_versions()
|
||||
devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count())
|
||||
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
|
||||
if jaxlib.version.__version_info__ >= (0, 4, 26):
|
||||
devices_to_check = (allowed_devices if allowed_devices else
|
||||
range(cuda_versions.cuda_device_count()))
|
||||
_check_cuda_compute_capability(devices_to_check)
|
||||
|
||||
return xla_client.make_gpu_client(
|
||||
|
Loading…
x
Reference in New Issue
Block a user