diff --git a/jax/_src/array.py b/jax/_src/array.py index 47bf85661..47a0a470b 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -40,7 +40,6 @@ from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe -from jax._src.lib import xla_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, @@ -212,90 +211,51 @@ class ArrayImpl(basearray.Array): arrays = self._check_and_rearrange(arrays, self._sharding, self.aval) self._arrays = arrays - if xla_extension_version >= 310: - def _check_and_rearrange(self, arrays, sharding, aval): - device_id_to_buffer = {_get_device(db).id: db for db in arrays} + def _check_and_rearrange(self, arrays, sharding, aval): + device_id_to_buffer = {_get_device(db).id: db for db in arrays} - addressable_dev = sharding.addressable_devices - if len(arrays) != len(addressable_dev): - raise ValueError( - f"Expected {len(addressable_dev)} per-device arrays " - "(this is how many devices are addressable by the sharding), but " - f"got {len(arrays)}") + addressable_dev = sharding.addressable_devices + if len(arrays) != len(addressable_dev): + raise ValueError( + f"Expected {len(addressable_dev)} per-device arrays " + "(this is how many devices are addressable by the sharding), but " + f"got {len(arrays)}") - array_device_ids = set(device_id_to_buffer.keys()) - addressable_device_ids = {d.id for d in addressable_dev} - if len(array_device_ids) != len(arrays): - buffer_device_ids = [_get_device(db).id for db in arrays] - raise ValueError( - "When making an array from single-device arrays, the input arrays" - " must be from distinct devices, but got device IDs" - f" {buffer_device_ids}") + array_device_ids = set(device_id_to_buffer.keys()) + addressable_device_ids = {d.id for d in addressable_dev} + if len(array_device_ids) != len(arrays): + buffer_device_ids = [_get_device(db).id for db in arrays] + raise ValueError( + "When making an array from single-device arrays, the input arrays" + " must be from distinct devices, but got device IDs" + f" {buffer_device_ids}") - # Calculate a symmetric difference because the device ids between sharding - # and _arrays should match. - diff = array_device_ids ^ addressable_device_ids - if diff: - dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids - dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids - err_msg = ( - "Addressable devices and per-device arrays devices do not match.") - if dev_in_sharding_not_in_arrays: - err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} " - "that are not present in per-device arrays.") - if dev_in_arrays_not_in_sharding: - err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} " - "that are not present in the sharding.") - raise ValueError(err_msg) + # Calculate a symmetric difference because the device ids between sharding + # and _arrays should match. + diff = array_device_ids ^ addressable_device_ids + if diff: + dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids + dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids + err_msg = ( + "Addressable devices and per-device arrays devices do not match.") + if dev_in_sharding_not_in_arrays: + err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} " + "that are not present in per-device arrays.") + if dev_in_arrays_not_in_sharding: + err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} " + "that are not present in the sharding.") + raise ValueError(err_msg) - _validate_shape_and_dtype_for_per_device_arrays( - arrays, - sharding=sharding, - aval=aval, - expected_shape=sharding.shard_shape(aval.shape), - ) + _validate_shape_and_dtype_for_per_device_arrays( + arrays, + sharding=sharding, + aval=aval, + expected_shape=sharding.shard_shape(aval.shape), + ) - # Rearrange arrays based on the device assignment. - addressable_da = sharding._addressable_device_assignment - return [device_id_to_buffer[device.id] for device in addressable_da] - else: - def _check_and_rearrange(self): # type: ignore - device_id_to_buffer = {_get_device(db).id: db for db in self._arrays} - - addressable_dev = self.sharding.addressable_devices - if len(self._arrays) != len(addressable_dev): - raise ValueError( - f"Expected {len(addressable_dev)} per-device arrays " - "(this is how many devices are addressable by the sharding), but " - f"got {len(self._arrays)}") - - array_device_ids = set(device_id_to_buffer.keys()) - addressable_device_ids = {d.id for d in addressable_dev} - # Calculate a symmetric difference because the device ids between sharding - # and _arrays should match. - diff = array_device_ids ^ addressable_device_ids - if diff: - dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids - dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids - err_msg = ( - "Addressable devices and per-device arrays devices do not match.") - if dev_in_sharding_not_in_arrays: - err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} " - "that are not present in per-device arrays.") - if dev_in_arrays_not_in_sharding: - err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} " - "that are not present in the sharding.") - raise ValueError(err_msg) - - _validate_shape_and_dtype_for_per_device_arrays( - self._arrays, - sharding=self.sharding, - aval=self.aval, - expected_shape=self.sharding.shard_shape(self.shape), - ) - # Rearrange arrays based on the device assignment. - addressable_da = self.sharding._addressable_device_assignment - self._arrays = [device_id_to_buffer[device.id] for device in addressable_da] + # Rearrange arrays based on the device assignment. + addressable_da = sharding._addressable_device_assignment + return [device_id_to_buffer[device.id] for device in addressable_da] @property def shape(self) -> Shape: @@ -652,18 +612,9 @@ class ArrayImpl(basearray.Array): db.block_until_ready() return self - if xla_extension_version >= 314: - @use_cpp_method() - def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore - ... # pytype: disable=bad-return-type - - else: - @use_cpp_method() - def _single_device_array_to_np_array(self): - return np.asarray(self._arrays[0]) - - def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: - return cast(np.ndarray, self._single_device_array_to_np_array()), True + @use_cpp_method() + def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore + ... # pytype: disable=bad-return-type @use_cpp_method() def _copy_single_device_array_to_host_async(self): diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index df545cb72..dea532d13 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -33,7 +33,6 @@ from jax._src import profiler from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -191,13 +190,12 @@ def get_compile_options( build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value build_options.memory_fitting_effort = config.memory_fitting_effort.value - if xla_extension_version >= 316: - build_options.optimization_level = config.EffortLevel( - config.optimization_level.value - ).value - build_options.memory_fitting_level = config.EffortLevel( - config.memory_fitting_level.value - ).value + build_options.optimization_level = config.EffortLevel( + config.optimization_level.value + ).value + build_options.memory_fitting_level = config.EffortLevel( + config.memory_fitting_level.value + ).value # This is a temporary workaround to simplify the AutoPGLE usage. # TODO(b/376647494): Remove once the bug is fixed. @@ -212,10 +210,9 @@ def get_compile_options( # Some overrides are passed directly on build_options. overrides_on_build_options = [ "exec_time_optimization_effort", "memory_fitting_effort"] - if xla_extension_version >= 316: - overrides_on_build_options.extend( - ["optimization_level", "memory_fitting_level"] - ) + overrides_on_build_options.extend( + ["optimization_level", "memory_fitting_level"] + ) env_options_overrides = dict(env_options_overrides) for name in overrides_on_build_options: diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 076f26c4a..050d6c394 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -45,7 +45,6 @@ from jax._src.interpreters import pxla from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.mesh import AbstractMesh, Mesh from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec @@ -419,12 +418,8 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): def _reorder_shards(x, new_s, copy_semantics: CopySemantics): """Reorders array shards to match the order indicated by the new sharding.""" - if xla_extension_version >= 304: - xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0] - return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore - else: - assert copy_semantics == CopySemantics.ALIAS - return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) + xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0] + return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore @dataclasses.dataclass(frozen=True) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index d8cdeecea..808d129ba 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -33,7 +33,6 @@ import ml_dtypes import numpy as np from jax._src import config -from jax._src.lib import xla_extension_version from jax._src.typing import Array, DType, DTypeLike from jax._src.util import set_module, StrictABC @@ -518,7 +517,7 @@ _complex_types: list[JAXType] = [ # only meant for the `jnp.isdtype` and we want to be conservative and not allow # StringDType to be used in there. _string_types: list[JAXType] = [] -if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311: +if hasattr(np.dtypes, 'StringDType'): _string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore _jax_dtype_set = { diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1370bb943..27e6ddac8 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -55,7 +55,6 @@ from jax._src.sharding_impls import (AUTO, NamedSharding, SdyArraySharding, SdyArrayShardingList) from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import dialects, ir, passmanager from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir import register_jax_dialects @@ -479,20 +478,13 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location: loc = ctx.traceback_caches.location_cache.get(code_lasti, None) if loc is None: frame = source_info_util.raw_frame_to_frame(code, lasti) - if xla_extension_version >= 309: - file_loc = ir.Location.file( - get_canonical_source_file(frame.file_name, ctx.traceback_caches), - frame.start_line, - frame.start_column, - frame.end_line, - frame.end_column, - ) - else: - file_loc = ir.Location.file( - get_canonical_source_file(frame.file_name, ctx.traceback_caches), - frame.start_line, - frame.start_column, - ) + file_loc = ir.Location.file( + get_canonical_source_file(frame.file_name, ctx.traceback_caches), + frame.start_line, + frame.start_column, + frame.end_line, + frame.end_column, + ) loc = ir.Location.name(frame.function_name, childLoc=file_loc) ctx.traceback_caches.location_cache[code_lasti] = loc frame_locs.append(loc) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 47791a216..a50576720 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -49,7 +49,6 @@ from jax._src.lax import lax as lax_internal from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.numpy.array_creation import (empty, empty_like, full, ones, ones_like, zeros, zeros_like) from jax._src.numpy import indexing @@ -5357,11 +5356,6 @@ def _make_string_array( ndmin: int = 0, device: xc.Device | Sharding | None = None, ) -> Array: - if xla_extension_version < 311: - raise TypeError( - "String arrays are not supported in JAX before XLA extension version" - " 311." - ) if not isinstance(object, np.ndarray): raise TypeError( "Currently, string arrays can only be made from NumPy" diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index ebd77b4f9..4e3bd0697 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -25,7 +25,6 @@ import jax._src.core as jax_core from jax._src.interpreters import mlir from jax._src.lib import triton from jax._src.lib import gpu_triton as triton_kernel_call_lib -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core from jax._src.pallas.triton import lowering @@ -98,7 +97,7 @@ def pallas_call_lowering( module_op.write_bytecode(buf) # TODO(b/394629193): Remove True once the bug is fixed. - if True or jaxlib_version < (0, 5, 1): + if True: # AOT Triton compilation is only available on jaxlib 0.5.1+. out_types = [ ir.RankedTensorType.get(bm.array_shape_dtype.shape, diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 5e8823795..98c1c20cd 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -591,12 +591,7 @@ def _call_tf_lowering( call = func_dialect.CallOp(callee_result_types, ir.FlatSymbolRefAttr.get(fn), tuple(args_op) + captured_ops) - if result_shape.is_tuple() and xla_client.mlir_api_version < 58: - # In API version 58, the results are always flattened. - flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i)) - for i in range(len(result_shapes))] - else: - flat_results = call.results + flat_results = call.results if ordered: raise NotImplementedError( diff --git a/jax/experimental/transfer.py b/jax/experimental/transfer.py index 39d25a5bb..1522df2cf 100644 --- a/jax/experimental/transfer.py +++ b/jax/experimental/transfer.py @@ -16,7 +16,6 @@ import jax from typing import Any, TYPE_CHECKING from jax._src.lib import xla_client as _xc -from jax._src.lib import xla_extension_version from jax._src.util import use_cpp_class, use_cpp_method class TransferConnection: @@ -42,7 +41,7 @@ class TransferConnection: return tree.unflatten(self._pull_flat(uuid, backend, xs_flat)) -if not TYPE_CHECKING and xla_extension_version >= 305: +if not TYPE_CHECKING: TransferConnection = use_cpp_class(_xc._xla.TransferConnection)(TransferConnection) @@ -67,8 +66,7 @@ class TransferServer: self._await_pull_flat(uuid, jax.tree.flatten(arrays)[0]) -if not TYPE_CHECKING and xla_extension_version >= 305: +if not TYPE_CHECKING: TransferServer = use_cpp_class(_xc._xla.TransferServer)(TransferServer) -if xla_extension_version >= 305: - start_transfer_server = _xc._xla.start_transfer_server +start_transfer_server = _xc._xla.start_transfer_server diff --git a/tests/api_test.py b/tests/api_test.py index 536fbf372..f145157cd 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -61,7 +61,6 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -1371,8 +1370,6 @@ class JitTest(jtu.BufferDonationTestCase): def f(x): return jnp.sqrt(x**2) + 1.0 - if xla_extension_version < 316: - self.skipTest("Requires XLA extension version >= 316") f_jit = jit( f, compiler_options={ @@ -1386,8 +1383,6 @@ class JitTest(jtu.BufferDonationTestCase): def f(x): return jnp.sqrt(x**2) + 1.0 - if xla_extension_version < 316: - self.skipTest("Requires XLA extension version >= 316") f_jit = jit( f, compiler_options={ diff --git a/tests/array_test.py b/tests/array_test.py index f961d2057..c9e888510 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -29,7 +29,6 @@ from jax._src import op_shardings from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip from jax._src.mesh import AxisTypes @@ -825,8 +824,6 @@ class JaxArrayTest(jtu.JaxTestCase): @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats) @jtu.run_on_devices("gpu") def test_pinned_host_npy_value_doesnt_cache(self, dtype): - if xla_extension_version < 314: - self.skipTest("Requires XLA extension version >= 314") # see https://github.com/jax-ml/jax/issues/26216 d_tensor = jnp.array(0, dtype=dtype) d_sharding = d_tensor.sharding diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 74592509f..d6abe8bec 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -25,7 +25,6 @@ from absl.testing import parameterized import jax from jax._src import config from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version from jax.experimental import colocated_python from jax.experimental.colocated_python import serialization from jax.extend.ifrt_programs import ifrt_programs @@ -383,11 +382,6 @@ class ColocatedPythonTest(jtu.JaxTestCase): del colocated_python._testing_global_state def testStringProcessing(self): - if xla_extension_version < 315: - self.skipTest( - "String support for colocated Python requires xla_extension_version" - " >= 315" - ) if np.lib.NumpyVersion(np.__version__) < "2.0.0": self.skipTest("StringDType requires NumPy 2.0.0 or later") cpu_devices = _colocated_cpu_devices(jax.local_devices()) @@ -431,11 +425,6 @@ class ColocatedPythonTest(jtu.JaxTestCase): ) def testBinaryDataProcessing(self): - if xla_extension_version < 315: - self.skipTest( - "String support for colocated Python requires xla_extension_version" - " >= 315" - ) if np.lib.NumpyVersion(np.__version__) < "2.0.0": self.skipTest("StringDType requires NumPy 2.0.0 or later") cpu_devices = _colocated_cpu_devices(jax.local_devices()) diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 531161eff..a8d59bc39 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -25,7 +25,6 @@ from jax._src import ad_checkpoint from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version import jax.numpy as jnp import numpy as np @@ -62,9 +61,6 @@ class DebugCallbackTest(jtu.JaxTestCase): @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.run_on_devices("cpu") def test_async_deadlock(self): - if xla_extension_version < 306: - self.skipTest("deadlock expected") - # See https://github.com/jax-ml/jax/issues/25861 def print_it(i, maxiter): self.assertIsInstance(i, jax.Array) diff --git a/tests/ffi_test.py b/tests/ffi_test.py index f38789085..46aaefa8f 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -32,7 +32,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src.interpreters import mlir from jax._src.layout import DeviceLocalLayout -from jax._src.lib import lapack, xla_extension_version +from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo from jax._src.lax import linalg as lax_linalg_internal from jax.experimental.shard_map import shard_map @@ -333,8 +333,6 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs): class BatchPartitioningTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_extension_version < 313: - self.skipTest("Requires XLA extension version >= 313") # Register callbacks before checking the number of devices to make sure # that we're testing the registration path, even if we can't run the tests. for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi", diff --git a/tests/layout_test.py b/tests/layout_test.py index 0520cf972..d98121b53 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -23,7 +23,6 @@ from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version from jax._src.util import safe_zip from jax.experimental.compute_on import compute_on @@ -33,12 +32,6 @@ jtu.request_cpu_devices(8) class LayoutTest(jtu.JaxTestCase): - # Remove this setUp once the released xla_extension_version is >= 308. - def setUp(self): - if xla_extension_version < 308 and not jtu.test_device_matches(['tpu', 'gpu']): - self.skipTest("Layouts do not work on CPU backend yet.") - super().setUp() - def test_auto_layout(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index cfa6536c0..fec41f199 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1707,8 +1707,6 @@ class ScipyLinalgTest(jtu.JaxTestCase): if pivoting: if not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("Pivoting is only supported on CPU and GPU.") - if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0): - self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0") rng = jtu.rand_default(self.rng()) jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting) sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting) diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py index 2400672d1..37bc88339 100644 --- a/tests/magma_linalg_test.py +++ b/tests/magma_linalg_test.py @@ -120,8 +120,6 @@ class MagmaLinalgTest(jtu.JaxTestCase): ) @jtu.run_on_devices("gpu") def testPivotedQrFactorization(self, shape, dtype): - if jtu.jaxlib_version() <= (0, 5, 0): - self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") rng = jtu.rand_default(self.rng()) @@ -133,8 +131,6 @@ class MagmaLinalgTest(jtu.JaxTestCase): self._CompileAndCheck(lax_func, args_maker) def testPivotedQrFactorizationMagmaConfig(self): - if jtu.jaxlib_version() <= (0, 5, 0): - self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") rng = jtu.rand_default(self.rng()) diff --git a/tests/memories_test.py b/tests/memories_test.py index 212d51c70..70c98fd51 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -26,7 +26,6 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.layout import DeviceLocalLayout as DLL, Layout -from jax._src.lib import xla_extension_version from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp @@ -1664,8 +1663,6 @@ class ComputeOffload(jtu.BufferDonationTestCase): class StreamAnnotationTest(jtu.JaxTestCase): def test_stream_annotation_inside_shmap(self): - if xla_extension_version < 313: - self.skipTest("Requires xla_extension_version >= 313") if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 36b0218d6..d96630df8 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -29,7 +29,6 @@ import jax from jax._src import config from jax._src import test_util as jtu from jax._src.interpreters import mlir -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir import passmanager from jax._src.lib.mlir.dialects import arith @@ -2701,9 +2700,6 @@ class UtilsTest(TestCase): class SerializationTest(absltest.TestCase): def test_pass_is_registered(self): - if jaxlib_version < (0, 5, 1): - self.skipTest("Test requires jaxlib 0.5.1 or later") - ctx = mlir.make_ir_context() ctx.allow_unregistered_dialects = True with ir.Location.unknown(ctx): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index bfb9451bd..f30de44d0 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -62,7 +62,6 @@ from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -7147,9 +7146,6 @@ class PJitErrorTest(jtu.JaxTestCase): f(arr, 2., 3.) # doesn't crash def test_named_sharding_of_none(self): - if xla_extension_version < 309: - raise unittest.SkipTest("NamedSharding does't reject None.") - mesh = jtu.create_mesh((2,), ('x',)) with self.assertRaisesRegex(TypeError, 'Unexpected None'): jax.NamedSharding(mesh, None) diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index dccfd8c38..48f3d062b 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -37,8 +37,6 @@ class RaggedCollectiveTest(jtu.JaxTestCase): super().setUp() if jtu.test_device_matches(['cpu']): self.skipTest('ragged-all-to-all is not supported on CPU') - if jtu.jaxlib_version() < (0, 5, 1): - self.skipTest('ragged-all-to-all is not supported on jaxlib version < 0.5.1') @parameterized.named_parameters( dict( diff --git a/tests/string_array_test.py b/tests/string_array_test.py index 4b8632ee4..364c71759 100644 --- a/tests/string_array_test.py +++ b/tests/string_array_test.py @@ -18,7 +18,6 @@ import jax from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version import numpy as np config.parse_flags_with_absl() @@ -29,12 +28,6 @@ class StringArrayTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_extension_version < 311: - self.skipTest( - "Skipping this test because the current XLA extension version:" - f" {xla_extension_version} is older than 309, the oldest version with" - " string array support." - ) if not hasattr(np.dtypes, "StringDType"): self.skipTest( "Skipping this test because the numpy.dtype.StringDType is not"