Removed unnecessary jaxlib version guards from xla_bridge

The minimum jaxlib version is 0.4.27.

PiperOrigin-RevId: 640513280
This commit is contained in:
Sergei Lebedev 2024-06-05 07:04:19 -07:00 committed by jax authors
parent 557cae65d1
commit e09cda8fa9

View File

@ -20,8 +20,6 @@ XLA. There are also a handful of related casting utilities.
"""
from __future__ import annotations
from __future__ import annotations
import atexit
from collections.abc import Mapping
import dataclasses
@ -32,23 +30,22 @@ import logging
import os
import pkgutil
import platform as py_platform
import traceback
import sys
import threading
import traceback
from typing import Any, Callable, Union
import warnings
from jax._src import config
from jax._src import distributed
from jax._src import hardware_utils
from jax._src import traceback_util
from jax._src import util
from jax._src import hardware_utils
from jax._src.cloud_tpu_init import maybe_import_libtpu
from jax._src.lib import cuda_versions
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib import jaxlib
logger = logging.getLogger(__name__)
@ -400,19 +397,16 @@ def _check_cuda_versions(raise_on_first_error: bool = False,
_version_check("cuPTI", cuda_versions.cupti_get_version,
cuda_versions.cupti_build_version,
min_supported_version=18)
# TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
if hasattr(cuda_versions, "cublas_get_version"):
_version_check("cuBLAS", cuda_versions.cublas_get_version,
cuda_versions.cublas_build_version,
# Ignore patch versions.
scale_for_comparison=100,
min_supported_version=120100)
if hasattr(cuda_versions, "cusparse_get_version"):
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
cuda_versions.cusparse_build_version,
# Ignore patch versions.
scale_for_comparison=100,
min_supported_version=12100)
_version_check("cuBLAS", cuda_versions.cublas_get_version,
cuda_versions.cublas_build_version,
# Ignore patch versions.
scale_for_comparison=100,
min_supported_version=120100)
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
cuda_versions.cusparse_build_version,
# Ignore patch versions.
scale_for_comparison=100,
min_supported_version=12100)
errors = []
debug_results = []
@ -454,11 +448,12 @@ def make_gpu_client(
print('Skipped CUDA versions constraints check due to the '
'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.')
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
if jaxlib.version.__version_info__ >= (0, 4, 26):
devices_to_check = (allowed_devices if allowed_devices else
range(cuda_versions.cuda_device_count()))
_check_cuda_compute_capability(devices_to_check)
devices_to_check = (
allowed_devices
if allowed_devices
else range(cuda_versions.cuda_device_count())
)
_check_cuda_compute_capability(devices_to_check)
return xla_client.make_gpu_client(
distributed_client=distributed.global_state.client,