Remove code present to support jaxlib < 0.5.1.

The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
This commit is contained in:
Peter Hawkins 2025-02-24 17:45:19 -05:00
parent 99a12ef9ea
commit 66293d8897
22 changed files with 68 additions and 206 deletions

View File

@ -40,7 +40,6 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding,
@ -212,90 +211,51 @@ class ArrayImpl(basearray.Array):
arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
self._arrays = arrays
if xla_extension_version >= 310:
def _check_and_rearrange(self, arrays, sharding, aval):
device_id_to_buffer = {_get_device(db).id: db for db in arrays}
def _check_and_rearrange(self, arrays, sharding, aval):
device_id_to_buffer = {_get_device(db).id: db for db in arrays}
addressable_dev = sharding.addressable_devices
if len(arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(arrays)}")
addressable_dev = sharding.addressable_devices
if len(arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
if len(array_device_ids) != len(arrays):
buffer_device_ids = [_get_device(db).id for db in arrays]
raise ValueError(
"When making an array from single-device arrays, the input arrays"
" must be from distinct devices, but got device IDs"
f" {buffer_device_ids}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
if len(array_device_ids) != len(arrays):
buffer_device_ids = [_get_device(db).id for db in arrays]
raise ValueError(
"When making an array from single-device arrays, the input arrays"
" must be from distinct devices, but got device IDs"
f" {buffer_device_ids}")
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays(
arrays,
sharding=sharding,
aval=aval,
expected_shape=sharding.shard_shape(aval.shape),
)
_validate_shape_and_dtype_for_per_device_arrays(
arrays,
sharding=sharding,
aval=aval,
expected_shape=sharding.shard_shape(aval.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = sharding._addressable_device_assignment
return [device_id_to_buffer[device.id] for device in addressable_da]
else:
def _check_and_rearrange(self): # type: ignore
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
# Rearrange arrays based on the device assignment.
addressable_da = sharding._addressable_device_assignment
return [device_id_to_buffer[device.id] for device in addressable_da]
@property
def shape(self) -> Shape:
@ -652,18 +612,9 @@ class ArrayImpl(basearray.Array):
db.block_until_ready()
return self
if xla_extension_version >= 314:
@use_cpp_method()
def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore
... # pytype: disable=bad-return-type
else:
@use_cpp_method()
def _single_device_array_to_np_array(self):
return np.asarray(self._arrays[0])
def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]:
return cast(np.ndarray, self._single_device_array_to_np_array()), True
@use_cpp_method()
def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore
... # pytype: disable=bad-return-type
@use_cpp_method()
def _copy_single_device_array_to_host_async(self):

View File

@ -33,7 +33,6 @@ from jax._src import profiler
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np
@ -191,13 +190,12 @@ def get_compile_options(
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
if xla_extension_version >= 316:
build_options.optimization_level = config.EffortLevel(
config.optimization_level.value
).value
build_options.memory_fitting_level = config.EffortLevel(
config.memory_fitting_level.value
).value
build_options.optimization_level = config.EffortLevel(
config.optimization_level.value
).value
build_options.memory_fitting_level = config.EffortLevel(
config.memory_fitting_level.value
).value
# This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed.
@ -212,10 +210,9 @@ def get_compile_options(
# Some overrides are passed directly on build_options.
overrides_on_build_options = [
"exec_time_optimization_effort", "memory_fitting_effort"]
if xla_extension_version >= 316:
overrides_on_build_options.extend(
["optimization_level", "memory_fitting_level"]
)
overrides_on_build_options.extend(
["optimization_level", "memory_fitting_level"]
)
env_options_overrides = dict(env_options_overrides)
for name in overrides_on_build_options:

View File

@ -45,7 +45,6 @@ from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.monitoring import record_event_duration_secs, record_event_time_span
from jax._src.partition_spec import PartitionSpec
@ -419,12 +418,8 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
def _reorder_shards(x, new_s, copy_semantics: CopySemantics):
"""Reorders array shards to match the order indicated by the new sharding."""
if xla_extension_version >= 304:
xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0]
return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore
else:
assert copy_semantics == CopySemantics.ALIAS
return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)
xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0]
return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore
@dataclasses.dataclass(frozen=True)

View File

@ -33,7 +33,6 @@ import ml_dtypes
import numpy as np
from jax._src import config
from jax._src.lib import xla_extension_version
from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC
@ -518,7 +517,7 @@ _complex_types: list[JAXType] = [
# only meant for the `jnp.isdtype` and we want to be conservative and not allow
# StringDType to be used in there.
_string_types: list[JAXType] = []
if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311:
if hasattr(np.dtypes, 'StringDType'):
_string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore
_jax_dtype_set = {

View File

@ -55,7 +55,6 @@ from jax._src.sharding_impls import (AUTO, NamedSharding,
SdyArraySharding, SdyArrayShardingList)
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects, ir, passmanager
from jax._src.lib.mlir.dialects import func as func_dialect, hlo
from jax._src.lib.mlir import register_jax_dialects
@ -479,20 +478,13 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
if loc is None:
frame = source_info_util.raw_frame_to_frame(code, lasti)
if xla_extension_version >= 309:
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
frame.end_line,
frame.end_column,
)
else:
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
)
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
frame.end_line,
frame.end_column,
)
loc = ir.Location.name(frame.function_name, childLoc=file_loc)
ctx.traceback_caches.location_cache[code_lasti] = loc
frame_locs.append(loc)

View File

@ -49,7 +49,6 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import (PrecisionLike,_array_copy,
_sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.numpy.array_creation import (empty, empty_like, full,
ones, ones_like, zeros, zeros_like)
from jax._src.numpy import indexing
@ -5357,11 +5356,6 @@ def _make_string_array(
ndmin: int = 0,
device: xc.Device | Sharding | None = None,
) -> Array:
if xla_extension_version < 311:
raise TypeError(
"String arrays are not supported in JAX before XLA extension version"
" 311."
)
if not isinstance(object, np.ndarray):
raise TypeError(
"Currently, string arrays can only be made from NumPy"

View File

@ -25,7 +25,6 @@ import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import triton
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.triton import lowering
@ -98,7 +97,7 @@ def pallas_call_lowering(
module_op.write_bytecode(buf)
# TODO(b/394629193): Remove True once the bug is fixed.
if True or jaxlib_version < (0, 5, 1):
if True:
# AOT Triton compilation is only available on jaxlib 0.5.1+.
out_types = [
ir.RankedTensorType.get(bm.array_shape_dtype.shape,

View File

@ -591,12 +591,7 @@ def _call_tf_lowering(
call = func_dialect.CallOp(callee_result_types,
ir.FlatSymbolRefAttr.get(fn),
tuple(args_op) + captured_ops)
if result_shape.is_tuple() and xla_client.mlir_api_version < 58:
# In API version 58, the results are always flattened.
flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
for i in range(len(result_shapes))]
else:
flat_results = call.results
flat_results = call.results
if ordered:
raise NotImplementedError(

View File

@ -16,7 +16,6 @@
import jax
from typing import Any, TYPE_CHECKING
from jax._src.lib import xla_client as _xc
from jax._src.lib import xla_extension_version
from jax._src.util import use_cpp_class, use_cpp_method
class TransferConnection:
@ -42,7 +41,7 @@ class TransferConnection:
return tree.unflatten(self._pull_flat(uuid, backend, xs_flat))
if not TYPE_CHECKING and xla_extension_version >= 305:
if not TYPE_CHECKING:
TransferConnection = use_cpp_class(_xc._xla.TransferConnection)(TransferConnection)
@ -67,8 +66,7 @@ class TransferServer:
self._await_pull_flat(uuid, jax.tree.flatten(arrays)[0])
if not TYPE_CHECKING and xla_extension_version >= 305:
if not TYPE_CHECKING:
TransferServer = use_cpp_class(_xc._xla.TransferServer)(TransferServer)
if xla_extension_version >= 305:
start_transfer_server = _xc._xla.start_transfer_server
start_transfer_server = _xc._xla.start_transfer_server

View File

@ -61,7 +61,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
@ -1371,8 +1370,6 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x):
return jnp.sqrt(x**2) + 1.0
if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit(
f,
compiler_options={
@ -1386,8 +1383,6 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x):
return jnp.sqrt(x**2) + 1.0
if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit(
f,
compiler_options={

View File

@ -29,7 +29,6 @@ from jax._src import op_shardings
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects, ir
from jax._src.util import safe_zip
from jax._src.mesh import AxisTypes
@ -825,8 +824,6 @@ class JaxArrayTest(jtu.JaxTestCase):
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
@jtu.run_on_devices("gpu")
def test_pinned_host_npy_value_doesnt_cache(self, dtype):
if xla_extension_version < 314:
self.skipTest("Requires XLA extension version >= 314")
# see https://github.com/jax-ml/jax/issues/26216
d_tensor = jnp.array(0, dtype=dtype)
d_sharding = d_tensor.sharding

View File

@ -25,7 +25,6 @@ from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import colocated_python
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
@ -383,11 +382,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
del colocated_python._testing_global_state
def testStringProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
self.skipTest("StringDType requires NumPy 2.0.0 or later")
cpu_devices = _colocated_cpu_devices(jax.local_devices())
@ -431,11 +425,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
)
def testBinaryDataProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
self.skipTest("StringDType requires NumPy 2.0.0 or later")
cpu_devices = _colocated_cpu_devices(jax.local_devices())

View File

@ -25,7 +25,6 @@ from jax._src import ad_checkpoint
from jax._src import debugging
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import jax.numpy as jnp
import numpy as np
@ -62,9 +61,6 @@ class DebugCallbackTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.run_on_devices("cpu")
def test_async_deadlock(self):
if xla_extension_version < 306:
self.skipTest("deadlock expected")
# See https://github.com/jax-ml/jax/issues/25861
def print_it(i, maxiter):
self.assertIsInstance(i, jax.Array)

View File

@ -32,7 +32,7 @@ from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack, xla_extension_version
from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
from jax.experimental.shard_map import shard_map
@ -333,8 +333,6 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs):
class BatchPartitioningTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_extension_version < 313:
self.skipTest("Requires XLA extension version >= 313")
# Register callbacks before checking the number of devices to make sure
# that we're testing the registration path, even if we can't run the tests.
for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi",

View File

@ -23,7 +23,6 @@ from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
from jax._src import config
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax.experimental.compute_on import compute_on
@ -33,12 +32,6 @@ jtu.request_cpu_devices(8)
class LayoutTest(jtu.JaxTestCase):
# Remove this setUp once the released xla_extension_version is >= 308.
def setUp(self):
if xla_extension_version < 308 and not jtu.test_device_matches(['tpu', 'gpu']):
self.skipTest("Layouts do not work on CPU backend yet.")
super().setUp()
def test_auto_layout(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape1 = (128, 128)

View File

@ -1707,8 +1707,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
if pivoting:
if not jtu.test_device_matches(["cpu", "gpu"]):
self.skipTest("Pivoting is only supported on CPU and GPU.")
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0")
rng = jtu.rand_default(self.rng())
jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)
sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting)

View File

@ -120,8 +120,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
)
@jtu.run_on_devices("gpu")
def testPivotedQrFactorization(self, shape, dtype):
if jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng())
@ -133,8 +131,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(lax_func, args_maker)
def testPivotedQrFactorizationMagmaConfig(self):
if jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng())

View File

@ -26,7 +26,6 @@ from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.layout import DeviceLocalLayout as DLL, Layout
from jax._src.lib import xla_extension_version
from jax._src import config
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
@ -1664,8 +1663,6 @@ class ComputeOffload(jtu.BufferDonationTestCase):
class StreamAnnotationTest(jtu.JaxTestCase):
def test_stream_annotation_inside_shmap(self):
if xla_extension_version < 313:
self.skipTest("Requires xla_extension_version >= 313")
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Stream annotation is only supported on GPU.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))

View File

@ -29,7 +29,6 @@ import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager
from jax._src.lib.mlir.dialects import arith
@ -2701,9 +2700,6 @@ class UtilsTest(TestCase):
class SerializationTest(absltest.TestCase):
def test_pass_is_registered(self):
if jaxlib_version < (0, 5, 1):
self.skipTest("Test requires jaxlib 0.5.1 or later")
ctx = mlir.make_ir_context()
ctx.allow_unregistered_dialects = True
with ir.Location.unknown(ctx):

View File

@ -62,7 +62,6 @@ from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2
config.parse_flags_with_absl()
@ -7147,9 +7146,6 @@ class PJitErrorTest(jtu.JaxTestCase):
f(arr, 2., 3.) # doesn't crash
def test_named_sharding_of_none(self):
if xla_extension_version < 309:
raise unittest.SkipTest("NamedSharding does't reject None.")
mesh = jtu.create_mesh((2,), ('x',))
with self.assertRaisesRegex(TypeError, 'Unexpected None'):
jax.NamedSharding(mesh, None)

View File

@ -37,8 +37,6 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
super().setUp()
if jtu.test_device_matches(['cpu']):
self.skipTest('ragged-all-to-all is not supported on CPU')
if jtu.jaxlib_version() < (0, 5, 1):
self.skipTest('ragged-all-to-all is not supported on jaxlib version < 0.5.1')
@parameterized.named_parameters(
dict(

View File

@ -18,7 +18,6 @@ import jax
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import numpy as np
config.parse_flags_with_absl()
@ -29,12 +28,6 @@ class StringArrayTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_extension_version < 311:
self.skipTest(
"Skipping this test because the current XLA extension version:"
f" {xla_extension_version} is older than 309, the oldest version with"
" string array support."
)
if not hasattr(np.dtypes, "StringDType"):
self.skipTest(
"Skipping this test because the numpy.dtype.StringDType is not"