mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Removed unnecessary jaxlib version guards from xla_bridge
The minimum jaxlib version is 0.4.27. PiperOrigin-RevId: 640513280
This commit is contained in:
parent
557cae65d1
commit
e09cda8fa9
@ -20,8 +20,6 @@ XLA. There are also a handful of related casting utilities.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
from collections.abc import Mapping
|
||||
import dataclasses
|
||||
@ -32,23 +30,22 @@ import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import platform as py_platform
|
||||
import traceback
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Callable, Union
|
||||
import warnings
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import distributed
|
||||
from jax._src import hardware_utils
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src import hardware_utils
|
||||
from jax._src.cloud_tpu_init import maybe_import_libtpu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import jaxlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -400,19 +397,16 @@ def _check_cuda_versions(raise_on_first_error: bool = False,
|
||||
_version_check("cuPTI", cuda_versions.cupti_get_version,
|
||||
cuda_versions.cupti_build_version,
|
||||
min_supported_version=18)
|
||||
# TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
|
||||
if hasattr(cuda_versions, "cublas_get_version"):
|
||||
_version_check("cuBLAS", cuda_versions.cublas_get_version,
|
||||
cuda_versions.cublas_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=120100)
|
||||
if hasattr(cuda_versions, "cusparse_get_version"):
|
||||
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
|
||||
cuda_versions.cusparse_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=12100)
|
||||
_version_check("cuBLAS", cuda_versions.cublas_get_version,
|
||||
cuda_versions.cublas_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=120100)
|
||||
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
|
||||
cuda_versions.cusparse_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100,
|
||||
min_supported_version=12100)
|
||||
|
||||
errors = []
|
||||
debug_results = []
|
||||
@ -454,11 +448,12 @@ def make_gpu_client(
|
||||
print('Skipped CUDA versions constraints check due to the '
|
||||
'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.')
|
||||
|
||||
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
|
||||
if jaxlib.version.__version_info__ >= (0, 4, 26):
|
||||
devices_to_check = (allowed_devices if allowed_devices else
|
||||
range(cuda_versions.cuda_device_count()))
|
||||
_check_cuda_compute_capability(devices_to_check)
|
||||
devices_to_check = (
|
||||
allowed_devices
|
||||
if allowed_devices
|
||||
else range(cuda_versions.cuda_device_count())
|
||||
)
|
||||
_check_cuda_compute_capability(devices_to_check)
|
||||
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
|
Loading…
x
Reference in New Issue
Block a user