Remove code present to support jaxlib < 0.5.1.

The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
This commit is contained in:
Peter Hawkins 2025-02-24 17:45:19 -05:00
parent 99a12ef9ea
commit 66293d8897
22 changed files with 68 additions and 206 deletions

View File

@ -40,7 +40,6 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding from jax._src.sharding import Sharding
from jax._src.sharding_impls import ( from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, PmapSharding, SingleDeviceSharding,
@ -212,90 +211,51 @@ class ArrayImpl(basearray.Array):
arrays = self._check_and_rearrange(arrays, self._sharding, self.aval) arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
self._arrays = arrays self._arrays = arrays
if xla_extension_version >= 310: def _check_and_rearrange(self, arrays, sharding, aval):
def _check_and_rearrange(self, arrays, sharding, aval): device_id_to_buffer = {_get_device(db).id: db for db in arrays}
device_id_to_buffer = {_get_device(db).id: db for db in arrays}
addressable_dev = sharding.addressable_devices addressable_dev = sharding.addressable_devices
if len(arrays) != len(addressable_dev): if len(arrays) != len(addressable_dev):
raise ValueError( raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays " f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but " "(this is how many devices are addressable by the sharding), but "
f"got {len(arrays)}") f"got {len(arrays)}")
array_device_ids = set(device_id_to_buffer.keys()) array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev} addressable_device_ids = {d.id for d in addressable_dev}
if len(array_device_ids) != len(arrays): if len(array_device_ids) != len(arrays):
buffer_device_ids = [_get_device(db).id for db in arrays] buffer_device_ids = [_get_device(db).id for db in arrays]
raise ValueError( raise ValueError(
"When making an array from single-device arrays, the input arrays" "When making an array from single-device arrays, the input arrays"
" must be from distinct devices, but got device IDs" " must be from distinct devices, but got device IDs"
f" {buffer_device_ids}") f" {buffer_device_ids}")
# Calculate a symmetric difference because the device ids between sharding # Calculate a symmetric difference because the device ids between sharding
# and _arrays should match. # and _arrays should match.
diff = array_device_ids ^ addressable_device_ids diff = array_device_ids ^ addressable_device_ids
if diff: if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = ( err_msg = (
"Addressable devices and per-device arrays devices do not match.") "Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays: if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} " err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.") "that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding: if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} " err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.") "that are not present in the sharding.")
raise ValueError(err_msg) raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays( _validate_shape_and_dtype_for_per_device_arrays(
arrays, arrays,
sharding=sharding, sharding=sharding,
aval=aval, aval=aval,
expected_shape=sharding.shard_shape(aval.shape), expected_shape=sharding.shard_shape(aval.shape),
) )
# Rearrange arrays based on the device assignment. # Rearrange arrays based on the device assignment.
addressable_da = sharding._addressable_device_assignment addressable_da = sharding._addressable_device_assignment
return [device_id_to_buffer[device.id] for device in addressable_da] return [device_id_to_buffer[device.id] for device in addressable_da]
else:
def _check_and_rearrange(self): # type: ignore
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
raise ValueError(
f"Expected {len(addressable_dev)} per-device arrays "
"(this is how many devices are addressable by the sharding), but "
f"got {len(self._arrays)}")
array_device_ids = set(device_id_to_buffer.keys())
addressable_device_ids = {d.id for d in addressable_dev}
# Calculate a symmetric difference because the device ids between sharding
# and _arrays should match.
diff = array_device_ids ^ addressable_device_ids
if diff:
dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
err_msg = (
"Addressable devices and per-device arrays devices do not match.")
if dev_in_sharding_not_in_arrays:
err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
"that are not present in per-device arrays.")
if dev_in_arrays_not_in_sharding:
err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
"that are not present in the sharding.")
raise ValueError(err_msg)
_validate_shape_and_dtype_for_per_device_arrays(
self._arrays,
sharding=self.sharding,
aval=self.aval,
expected_shape=self.sharding.shard_shape(self.shape),
)
# Rearrange arrays based on the device assignment.
addressable_da = self.sharding._addressable_device_assignment
self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
@property @property
def shape(self) -> Shape: def shape(self) -> Shape:
@ -652,18 +612,9 @@ class ArrayImpl(basearray.Array):
db.block_until_ready() db.block_until_ready()
return self return self
if xla_extension_version >= 314: @use_cpp_method()
@use_cpp_method() def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore
def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore ... # pytype: disable=bad-return-type
... # pytype: disable=bad-return-type
else:
@use_cpp_method()
def _single_device_array_to_np_array(self):
return np.asarray(self._arrays[0])
def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]:
return cast(np.ndarray, self._single_device_array_to_np_array()), True
@use_cpp_method() @use_cpp_method()
def _copy_single_device_array_to_host_async(self): def _copy_single_device_array_to_host_async(self):

View File

@ -33,7 +33,6 @@ from jax._src import profiler
from jax._src import traceback_util from jax._src import traceback_util
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
import numpy as np import numpy as np
@ -191,13 +190,12 @@ def get_compile_options(
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value build_options.memory_fitting_effort = config.memory_fitting_effort.value
if xla_extension_version >= 316: build_options.optimization_level = config.EffortLevel(
build_options.optimization_level = config.EffortLevel( config.optimization_level.value
config.optimization_level.value ).value
).value build_options.memory_fitting_level = config.EffortLevel(
build_options.memory_fitting_level = config.EffortLevel( config.memory_fitting_level.value
config.memory_fitting_level.value ).value
).value
# This is a temporary workaround to simplify the AutoPGLE usage. # This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed. # TODO(b/376647494): Remove once the bug is fixed.
@ -212,10 +210,9 @@ def get_compile_options(
# Some overrides are passed directly on build_options. # Some overrides are passed directly on build_options.
overrides_on_build_options = [ overrides_on_build_options = [
"exec_time_optimization_effort", "memory_fitting_effort"] "exec_time_optimization_effort", "memory_fitting_effort"]
if xla_extension_version >= 316: overrides_on_build_options.extend(
overrides_on_build_options.extend( ["optimization_level", "memory_fitting_level"]
["optimization_level", "memory_fitting_level"] )
)
env_options_overrides = dict(env_options_overrides) env_options_overrides = dict(env_options_overrides)
for name in overrides_on_build_options: for name in overrides_on_build_options:

View File

@ -45,7 +45,6 @@ from jax._src.interpreters import pxla
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, Layout from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.mesh import AbstractMesh, Mesh from jax._src.mesh import AbstractMesh, Mesh
from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.monitoring import record_event_duration_secs, record_event_time_span
from jax._src.partition_spec import PartitionSpec from jax._src.partition_spec import PartitionSpec
@ -419,12 +418,8 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
def _reorder_shards(x, new_s, copy_semantics: CopySemantics): def _reorder_shards(x, new_s, copy_semantics: CopySemantics):
"""Reorders array shards to match the order indicated by the new sharding.""" """Reorders array shards to match the order indicated by the new sharding."""
if xla_extension_version >= 304: xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0]
xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0] return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore
return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore
else:
assert copy_semantics == CopySemantics.ALIAS
return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)

View File

@ -33,7 +33,6 @@ import ml_dtypes
import numpy as np import numpy as np
from jax._src import config from jax._src import config
from jax._src.lib import xla_extension_version
from jax._src.typing import Array, DType, DTypeLike from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC from jax._src.util import set_module, StrictABC
@ -518,7 +517,7 @@ _complex_types: list[JAXType] = [
# only meant for the `jnp.isdtype` and we want to be conservative and not allow # only meant for the `jnp.isdtype` and we want to be conservative and not allow
# StringDType to be used in there. # StringDType to be used in there.
_string_types: list[JAXType] = [] _string_types: list[JAXType] = []
if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311: if hasattr(np.dtypes, 'StringDType'):
_string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore _string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore
_jax_dtype_set = { _jax_dtype_set = {

View File

@ -55,7 +55,6 @@ from jax._src.sharding_impls import (AUTO, NamedSharding,
SdyArraySharding, SdyArrayShardingList) SdyArraySharding, SdyArrayShardingList)
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects, ir, passmanager from jax._src.lib.mlir import dialects, ir, passmanager
from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir.dialects import func as func_dialect, hlo
from jax._src.lib.mlir import register_jax_dialects from jax._src.lib.mlir import register_jax_dialects
@ -479,20 +478,13 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
loc = ctx.traceback_caches.location_cache.get(code_lasti, None) loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
if loc is None: if loc is None:
frame = source_info_util.raw_frame_to_frame(code, lasti) frame = source_info_util.raw_frame_to_frame(code, lasti)
if xla_extension_version >= 309: file_loc = ir.Location.file(
file_loc = ir.Location.file( get_canonical_source_file(frame.file_name, ctx.traceback_caches),
get_canonical_source_file(frame.file_name, ctx.traceback_caches), frame.start_line,
frame.start_line, frame.start_column,
frame.start_column, frame.end_line,
frame.end_line, frame.end_column,
frame.end_column, )
)
else:
file_loc = ir.Location.file(
get_canonical_source_file(frame.file_name, ctx.traceback_caches),
frame.start_line,
frame.start_column,
)
loc = ir.Location.name(frame.function_name, childLoc=file_loc) loc = ir.Location.name(frame.function_name, childLoc=file_loc)
ctx.traceback_caches.location_cache[code_lasti] = loc ctx.traceback_caches.location_cache[code_lasti] = loc
frame_locs.append(loc) frame_locs.append(loc)

View File

@ -49,7 +49,6 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import (PrecisionLike,_array_copy, from jax._src.lax.lax import (PrecisionLike,_array_copy,
_sort_le_comparator, _sort_lt_comparator) _sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.numpy.array_creation import (empty, empty_like, full, from jax._src.numpy.array_creation import (empty, empty_like, full,
ones, ones_like, zeros, zeros_like) ones, ones_like, zeros, zeros_like)
from jax._src.numpy import indexing from jax._src.numpy import indexing
@ -5357,11 +5356,6 @@ def _make_string_array(
ndmin: int = 0, ndmin: int = 0,
device: xc.Device | Sharding | None = None, device: xc.Device | Sharding | None = None,
) -> Array: ) -> Array:
if xla_extension_version < 311:
raise TypeError(
"String arrays are not supported in JAX before XLA extension version"
" 311."
)
if not isinstance(object, np.ndarray): if not isinstance(object, np.ndarray):
raise TypeError( raise TypeError(
"Currently, string arrays can only be made from NumPy" "Currently, string arrays can only be made from NumPy"

View File

@ -25,7 +25,6 @@ import jax._src.core as jax_core
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.lib import triton from jax._src.lib import triton
from jax._src.lib import gpu_triton as triton_kernel_call_lib from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core from jax._src.pallas import core as pallas_core
from jax._src.pallas.triton import lowering from jax._src.pallas.triton import lowering
@ -98,7 +97,7 @@ def pallas_call_lowering(
module_op.write_bytecode(buf) module_op.write_bytecode(buf)
# TODO(b/394629193): Remove True once the bug is fixed. # TODO(b/394629193): Remove True once the bug is fixed.
if True or jaxlib_version < (0, 5, 1): if True:
# AOT Triton compilation is only available on jaxlib 0.5.1+. # AOT Triton compilation is only available on jaxlib 0.5.1+.
out_types = [ out_types = [
ir.RankedTensorType.get(bm.array_shape_dtype.shape, ir.RankedTensorType.get(bm.array_shape_dtype.shape,

View File

@ -591,12 +591,7 @@ def _call_tf_lowering(
call = func_dialect.CallOp(callee_result_types, call = func_dialect.CallOp(callee_result_types,
ir.FlatSymbolRefAttr.get(fn), ir.FlatSymbolRefAttr.get(fn),
tuple(args_op) + captured_ops) tuple(args_op) + captured_ops)
if result_shape.is_tuple() and xla_client.mlir_api_version < 58: flat_results = call.results
# In API version 58, the results are always flattened.
flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
for i in range(len(result_shapes))]
else:
flat_results = call.results
if ordered: if ordered:
raise NotImplementedError( raise NotImplementedError(

View File

@ -16,7 +16,6 @@
import jax import jax
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
from jax._src.lib import xla_client as _xc from jax._src.lib import xla_client as _xc
from jax._src.lib import xla_extension_version
from jax._src.util import use_cpp_class, use_cpp_method from jax._src.util import use_cpp_class, use_cpp_method
class TransferConnection: class TransferConnection:
@ -42,7 +41,7 @@ class TransferConnection:
return tree.unflatten(self._pull_flat(uuid, backend, xs_flat)) return tree.unflatten(self._pull_flat(uuid, backend, xs_flat))
if not TYPE_CHECKING and xla_extension_version >= 305: if not TYPE_CHECKING:
TransferConnection = use_cpp_class(_xc._xla.TransferConnection)(TransferConnection) TransferConnection = use_cpp_class(_xc._xla.TransferConnection)(TransferConnection)
@ -67,8 +66,7 @@ class TransferServer:
self._await_pull_flat(uuid, jax.tree.flatten(arrays)[0]) self._await_pull_flat(uuid, jax.tree.flatten(arrays)[0])
if not TYPE_CHECKING and xla_extension_version >= 305: if not TYPE_CHECKING:
TransferServer = use_cpp_class(_xc._xla.TransferServer)(TransferServer) TransferServer = use_cpp_class(_xc._xla.TransferServer)(TransferServer)
if xla_extension_version >= 305: start_transfer_server = _xc._xla.start_transfer_server
start_transfer_server = _xc._xla.start_transfer_server

View File

@ -61,7 +61,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_extension from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching import jax.custom_batching
@ -1371,8 +1370,6 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x): def f(x):
return jnp.sqrt(x**2) + 1.0 return jnp.sqrt(x**2) + 1.0
if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit( f_jit = jit(
f, f,
compiler_options={ compiler_options={
@ -1386,8 +1383,6 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x): def f(x):
return jnp.sqrt(x**2) + 1.0 return jnp.sqrt(x**2) + 1.0
if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit( f_jit = jit(
f, f,
compiler_options={ compiler_options={

View File

@ -29,7 +29,6 @@ from jax._src import op_shardings
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import xla_bridge as xb from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects, ir from jax._src.lib.mlir import dialects, ir
from jax._src.util import safe_zip from jax._src.util import safe_zip
from jax._src.mesh import AxisTypes from jax._src.mesh import AxisTypes
@ -825,8 +824,6 @@ class JaxArrayTest(jtu.JaxTestCase):
@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats) @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
@jtu.run_on_devices("gpu") @jtu.run_on_devices("gpu")
def test_pinned_host_npy_value_doesnt_cache(self, dtype): def test_pinned_host_npy_value_doesnt_cache(self, dtype):
if xla_extension_version < 314:
self.skipTest("Requires XLA extension version >= 314")
# see https://github.com/jax-ml/jax/issues/26216 # see https://github.com/jax-ml/jax/issues/26216
d_tensor = jnp.array(0, dtype=dtype) d_tensor = jnp.array(0, dtype=dtype)
d_sharding = d_tensor.sharding d_sharding = d_tensor.sharding

View File

@ -25,7 +25,6 @@ from absl.testing import parameterized
import jax import jax
from jax._src import config from jax._src import config
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import colocated_python from jax.experimental import colocated_python
from jax.experimental.colocated_python import serialization from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs from jax.extend.ifrt_programs import ifrt_programs
@ -383,11 +382,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
del colocated_python._testing_global_state del colocated_python._testing_global_state
def testStringProcessing(self): def testStringProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
if np.lib.NumpyVersion(np.__version__) < "2.0.0": if np.lib.NumpyVersion(np.__version__) < "2.0.0":
self.skipTest("StringDType requires NumPy 2.0.0 or later") self.skipTest("StringDType requires NumPy 2.0.0 or later")
cpu_devices = _colocated_cpu_devices(jax.local_devices()) cpu_devices = _colocated_cpu_devices(jax.local_devices())
@ -431,11 +425,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
) )
def testBinaryDataProcessing(self): def testBinaryDataProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
if np.lib.NumpyVersion(np.__version__) < "2.0.0": if np.lib.NumpyVersion(np.__version__) < "2.0.0":
self.skipTest("StringDType requires NumPy 2.0.0 or later") self.skipTest("StringDType requires NumPy 2.0.0 or later")
cpu_devices = _colocated_cpu_devices(jax.local_devices()) cpu_devices = _colocated_cpu_devices(jax.local_devices())

View File

@ -25,7 +25,6 @@ from jax._src import ad_checkpoint
from jax._src import debugging from jax._src import debugging
from jax._src import dispatch from jax._src import dispatch
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
@ -62,9 +61,6 @@ class DebugCallbackTest(jtu.JaxTestCase):
@jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.run_on_devices("cpu") @jtu.run_on_devices("cpu")
def test_async_deadlock(self): def test_async_deadlock(self):
if xla_extension_version < 306:
self.skipTest("deadlock expected")
# See https://github.com/jax-ml/jax/issues/25861 # See https://github.com/jax-ml/jax/issues/25861
def print_it(i, maxiter): def print_it(i, maxiter):
self.assertIsInstance(i, jax.Array) self.assertIsInstance(i, jax.Array)

View File

@ -32,7 +32,7 @@ from jax._src import dispatch
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack, xla_extension_version from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal from jax._src.lax import linalg as lax_linalg_internal
from jax.experimental.shard_map import shard_map from jax.experimental.shard_map import shard_map
@ -333,8 +333,6 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs):
class BatchPartitioningTest(jtu.JaxTestCase): class BatchPartitioningTest(jtu.JaxTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
if xla_extension_version < 313:
self.skipTest("Requires XLA extension version >= 313")
# Register callbacks before checking the number of devices to make sure # Register callbacks before checking the number of devices to make sure
# that we're testing the registration path, even if we can't run the tests. # that we're testing the registration path, even if we can't run the tests.
for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi", for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi",

View File

@ -23,7 +23,6 @@ from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
from jax._src import config from jax._src import config
from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip from jax._src.util import safe_zip
from jax.experimental.compute_on import compute_on from jax.experimental.compute_on import compute_on
@ -33,12 +32,6 @@ jtu.request_cpu_devices(8)
class LayoutTest(jtu.JaxTestCase): class LayoutTest(jtu.JaxTestCase):
# Remove this setUp once the released xla_extension_version is >= 308.
def setUp(self):
if xla_extension_version < 308 and not jtu.test_device_matches(['tpu', 'gpu']):
self.skipTest("Layouts do not work on CPU backend yet.")
super().setUp()
def test_auto_layout(self): def test_auto_layout(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y')) mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape1 = (128, 128) shape1 = (128, 128)

View File

@ -1707,8 +1707,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
if pivoting: if pivoting:
if not jtu.test_device_matches(["cpu", "gpu"]): if not jtu.test_device_matches(["cpu", "gpu"]):
self.skipTest("Pivoting is only supported on CPU and GPU.") self.skipTest("Pivoting is only supported on CPU and GPU.")
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0")
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting) jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)
sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting) sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting)

View File

@ -120,8 +120,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
) )
@jtu.run_on_devices("gpu") @jtu.run_on_devices("gpu")
def testPivotedQrFactorization(self, shape, dtype): def testPivotedQrFactorization(self, shape, dtype):
if jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
if not gpu_solver.has_magma(): if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.") self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
@ -133,8 +131,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(lax_func, args_maker) self._CompileAndCheck(lax_func, args_maker)
def testPivotedQrFactorizationMagmaConfig(self): def testPivotedQrFactorizationMagmaConfig(self):
if jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
if not gpu_solver.has_magma(): if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.") self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())

View File

@ -26,7 +26,6 @@ from jax import lax
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import xla_bridge as xb from jax._src import xla_bridge as xb
from jax._src.layout import DeviceLocalLayout as DLL, Layout from jax._src.layout import DeviceLocalLayout as DLL, Layout
from jax._src.lib import xla_extension_version
from jax._src import config from jax._src import config
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp import jax.numpy as jnp
@ -1664,8 +1663,6 @@ class ComputeOffload(jtu.BufferDonationTestCase):
class StreamAnnotationTest(jtu.JaxTestCase): class StreamAnnotationTest(jtu.JaxTestCase):
def test_stream_annotation_inside_shmap(self): def test_stream_annotation_inside_shmap(self):
if xla_extension_version < 313:
self.skipTest("Requires xla_extension_version >= 313")
if not jtu.test_device_matches(["gpu"]): if not jtu.test_device_matches(["gpu"]):
self.skipTest("Stream annotation is only supported on GPU.") self.skipTest("Stream annotation is only supported on GPU.")
mesh = jtu.create_mesh((2, 2), ('x', 'y')) mesh = jtu.create_mesh((2, 2), ('x', 'y'))

View File

@ -29,7 +29,6 @@ import jax
from jax._src import config from jax._src import config
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager from jax._src.lib.mlir import passmanager
from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import arith
@ -2701,9 +2700,6 @@ class UtilsTest(TestCase):
class SerializationTest(absltest.TestCase): class SerializationTest(absltest.TestCase):
def test_pass_is_registered(self): def test_pass_is_registered(self):
if jaxlib_version < (0, 5, 1):
self.skipTest("Test requires jaxlib 0.5.1 or later")
ctx = mlir.make_ir_context() ctx = mlir.make_ir_context()
ctx.allow_unregistered_dialects = True ctx.allow_unregistered_dialects = True
with ir.Location.unknown(ctx): with ir.Location.unknown(ctx):

View File

@ -62,7 +62,6 @@ from jax._src.lib.mlir import dialects
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2 from jax._src.util import curry, unzip2
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -7147,9 +7146,6 @@ class PJitErrorTest(jtu.JaxTestCase):
f(arr, 2., 3.) # doesn't crash f(arr, 2., 3.) # doesn't crash
def test_named_sharding_of_none(self): def test_named_sharding_of_none(self):
if xla_extension_version < 309:
raise unittest.SkipTest("NamedSharding does't reject None.")
mesh = jtu.create_mesh((2,), ('x',)) mesh = jtu.create_mesh((2,), ('x',))
with self.assertRaisesRegex(TypeError, 'Unexpected None'): with self.assertRaisesRegex(TypeError, 'Unexpected None'):
jax.NamedSharding(mesh, None) jax.NamedSharding(mesh, None)

View File

@ -37,8 +37,6 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
super().setUp() super().setUp()
if jtu.test_device_matches(['cpu']): if jtu.test_device_matches(['cpu']):
self.skipTest('ragged-all-to-all is not supported on CPU') self.skipTest('ragged-all-to-all is not supported on CPU')
if jtu.jaxlib_version() < (0, 5, 1):
self.skipTest('ragged-all-to-all is not supported on jaxlib version < 0.5.1')
@parameterized.named_parameters( @parameterized.named_parameters(
dict( dict(

View File

@ -18,7 +18,6 @@ import jax
from jax import numpy as jnp from jax import numpy as jnp
from jax._src import config from jax._src import config
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import numpy as np import numpy as np
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -29,12 +28,6 @@ class StringArrayTest(jtu.JaxTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
if xla_extension_version < 311:
self.skipTest(
"Skipping this test because the current XLA extension version:"
f" {xla_extension_version} is older than 309, the oldest version with"
" string array support."
)
if not hasattr(np.dtypes, "StringDType"): if not hasattr(np.dtypes, "StringDType"):
self.skipTest( self.skipTest(
"Skipping this test because the numpy.dtype.StringDType is not" "Skipping this test because the numpy.dtype.StringDType is not"