Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
Commit 66293d8897 (parent 99a12ef9ea)
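The change applies one mechanical transformation throughout: delete version guards on `xla_extension_version` / `jaxlib_version` that the new minimums (xla_extension_version 317, jaxlib 0.5.1, mlir_api_version 58) render always true, keep the guarded path, and drop the fallback. A minimal, hypothetical sketch of the pattern being removed (the helpers `new_path` and `old_path` are placeholders for illustration, not JAX APIs):

  from jax._src.lib import xla_extension_version

  def new_path():  # hypothetical stand-in for a code path added in newer jaxlib
    return "new"

  def old_path():  # hypothetical stand-in for the fallback kept for older jaxlib
    return "old"

  # Before this commit: dispatch on the version reported by the installed jaxlib.
  if xla_extension_version >= 316:
    result = new_path()
  else:
    result = old_path()

  # After this commit: the minimum jaxlib guarantees the guard holds
  # (xla_extension_version is at least 317), so only the new path remains.
  result = new_path()

Each hunk below is an instance of this cleanup; several also drop test skips that guarded against older jaxlib releases.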
@@ -40,7 +40,6 @@ from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension as xe
-from jax._src.lib import xla_extension_version
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding,
@@ -212,90 +211,51 @@ class ArrayImpl(basearray.Array):
       arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
     self._arrays = arrays
 
-  if xla_extension_version >= 310:
-    def _check_and_rearrange(self, arrays, sharding, aval):
-      device_id_to_buffer = {_get_device(db).id: db for db in arrays}
+  def _check_and_rearrange(self, arrays, sharding, aval):
+    device_id_to_buffer = {_get_device(db).id: db for db in arrays}
 
-      addressable_dev = sharding.addressable_devices
-      if len(arrays) != len(addressable_dev):
-        raise ValueError(
-            f"Expected {len(addressable_dev)} per-device arrays "
-            "(this is how many devices are addressable by the sharding), but "
-            f"got {len(arrays)}")
+    addressable_dev = sharding.addressable_devices
+    if len(arrays) != len(addressable_dev):
+      raise ValueError(
+          f"Expected {len(addressable_dev)} per-device arrays "
+          "(this is how many devices are addressable by the sharding), but "
+          f"got {len(arrays)}")
 
-      array_device_ids = set(device_id_to_buffer.keys())
-      addressable_device_ids = {d.id for d in addressable_dev}
-      if len(array_device_ids) != len(arrays):
-        buffer_device_ids = [_get_device(db).id for db in arrays]
-        raise ValueError(
-            "When making an array from single-device arrays, the input arrays"
-            " must be from distinct devices, but got device IDs"
-            f" {buffer_device_ids}")
+    array_device_ids = set(device_id_to_buffer.keys())
+    addressable_device_ids = {d.id for d in addressable_dev}
+    if len(array_device_ids) != len(arrays):
+      buffer_device_ids = [_get_device(db).id for db in arrays]
+      raise ValueError(
+          "When making an array from single-device arrays, the input arrays"
+          " must be from distinct devices, but got device IDs"
+          f" {buffer_device_ids}")
 
-      # Calculate a symmetric difference because the device ids between sharding
-      # and _arrays should match.
-      diff = array_device_ids ^ addressable_device_ids
-      if diff:
-        dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
-        dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
-        err_msg = (
-            "Addressable devices and per-device arrays devices do not match.")
-        if dev_in_sharding_not_in_arrays:
-          err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
-                      "that are not present in per-device arrays.")
-        if dev_in_arrays_not_in_sharding:
-          err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
-                      "that are not present in the sharding.")
-        raise ValueError(err_msg)
+    # Calculate a symmetric difference because the device ids between sharding
+    # and _arrays should match.
+    diff = array_device_ids ^ addressable_device_ids
+    if diff:
+      dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
+      dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
+      err_msg = (
+          "Addressable devices and per-device arrays devices do not match.")
+      if dev_in_sharding_not_in_arrays:
+        err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
+                    "that are not present in per-device arrays.")
+      if dev_in_arrays_not_in_sharding:
+        err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
+                    "that are not present in the sharding.")
+      raise ValueError(err_msg)
 
-      _validate_shape_and_dtype_for_per_device_arrays(
-          arrays,
-          sharding=sharding,
-          aval=aval,
-          expected_shape=sharding.shard_shape(aval.shape),
-      )
+    _validate_shape_and_dtype_for_per_device_arrays(
+        arrays,
+        sharding=sharding,
+        aval=aval,
+        expected_shape=sharding.shard_shape(aval.shape),
+    )
 
-      # Rearrange arrays based on the device assignment.
-      addressable_da = sharding._addressable_device_assignment
-      return [device_id_to_buffer[device.id] for device in addressable_da]
-  else:
-    def _check_and_rearrange(self):  # type: ignore
-      device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}
-
-      addressable_dev = self.sharding.addressable_devices
-      if len(self._arrays) != len(addressable_dev):
-        raise ValueError(
-            f"Expected {len(addressable_dev)} per-device arrays "
-            "(this is how many devices are addressable by the sharding), but "
-            f"got {len(self._arrays)}")
-
-      array_device_ids = set(device_id_to_buffer.keys())
-      addressable_device_ids = {d.id for d in addressable_dev}
-      # Calculate a symmetric difference because the device ids between sharding
-      # and _arrays should match.
-      diff = array_device_ids ^ addressable_device_ids
-      if diff:
-        dev_in_sharding_not_in_arrays = addressable_device_ids - array_device_ids
-        dev_in_arrays_not_in_sharding = array_device_ids - addressable_device_ids
-        err_msg = (
-            "Addressable devices and per-device arrays devices do not match.")
-        if dev_in_sharding_not_in_arrays:
-          err_msg += (f" Sharding contains devices {dev_in_sharding_not_in_arrays} "
-                      "that are not present in per-device arrays.")
-        if dev_in_arrays_not_in_sharding:
-          err_msg += (f" Per-device arrays contain devices {dev_in_arrays_not_in_sharding} "
-                      "that are not present in the sharding.")
-        raise ValueError(err_msg)
-
-      _validate_shape_and_dtype_for_per_device_arrays(
-          self._arrays,
-          sharding=self.sharding,
-          aval=self.aval,
-          expected_shape=self.sharding.shard_shape(self.shape),
-      )
-      # Rearrange arrays based on the device assignment.
-      addressable_da = self.sharding._addressable_device_assignment
-      self._arrays = [device_id_to_buffer[device.id] for device in addressable_da]
+    # Rearrange arrays based on the device assignment.
+    addressable_da = sharding._addressable_device_assignment
+    return [device_id_to_buffer[device.id] for device in addressable_da]
 
   @property
   def shape(self) -> Shape:
@@ -652,18 +612,9 @@ class ArrayImpl(basearray.Array):
       db.block_until_ready()
     return self
 
-  if xla_extension_version >= 314:
-    @use_cpp_method()
-    def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]:  # type: ignore
-      ...  # pytype: disable=bad-return-type
-
-  else:
-    @use_cpp_method()
-    def _single_device_array_to_np_array(self):
-      return np.asarray(self._arrays[0])
-
-    def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]:
-      return cast(np.ndarray, self._single_device_array_to_np_array()), True
+  @use_cpp_method()
+  def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]:  # type: ignore
+    ...  # pytype: disable=bad-return-type
 
   @use_cpp_method()
   def _copy_single_device_array_to_host_async(self):
@@ -33,7 +33,6 @@ from jax._src import profiler
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 import numpy as np
 
@@ -191,13 +190,12 @@ def get_compile_options(
 
   build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
   build_options.memory_fitting_effort = config.memory_fitting_effort.value
-  if xla_extension_version >= 316:
-    build_options.optimization_level = config.EffortLevel(
-        config.optimization_level.value
-    ).value
-    build_options.memory_fitting_level = config.EffortLevel(
-        config.memory_fitting_level.value
-    ).value
+  build_options.optimization_level = config.EffortLevel(
+      config.optimization_level.value
+  ).value
+  build_options.memory_fitting_level = config.EffortLevel(
+      config.memory_fitting_level.value
+  ).value
 
   # This is a temporary workaround to simplify the AutoPGLE usage.
   # TODO(b/376647494): Remove once the bug is fixed.
@@ -212,10 +210,9 @@ def get_compile_options(
   # Some overrides are passed directly on build_options.
   overrides_on_build_options = [
       "exec_time_optimization_effort", "memory_fitting_effort"]
-  if xla_extension_version >= 316:
-    overrides_on_build_options.extend(
-        ["optimization_level", "memory_fitting_level"]
-    )
+  overrides_on_build_options.extend(
+      ["optimization_level", "memory_fitting_level"]
+  )
 
   env_options_overrides = dict(env_options_overrides)
   for name in overrides_on_build_options:
@@ -45,7 +45,6 @@ from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.monitoring import record_event_duration_secs, record_event_time_span
 from jax._src.partition_spec import PartitionSpec
@@ -419,12 +418,8 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
 
 def _reorder_shards(x, new_s, copy_semantics: CopySemantics):
   """Reorders array shards to match the order indicated by the new sharding."""
-  if xla_extension_version >= 304:
-    xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0]
-    return xc.reorder_shards(x, new_s, xc_copy_semantics)  # type: ignore
-  else:
-    assert copy_semantics == CopySemantics.ALIAS
-    return array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays)
+  xc_copy_semantics = pxla.to_xc_copy_semantics([copy_semantics])[0]
+  return xc.reorder_shards(x, new_s, xc_copy_semantics)  # type: ignore
 
 
 @dataclasses.dataclass(frozen=True)
@@ -33,7 +33,6 @@ import ml_dtypes
 import numpy as np
 
 from jax._src import config
-from jax._src.lib import xla_extension_version
 from jax._src.typing import Array, DType, DTypeLike
 from jax._src.util import set_module, StrictABC
 
@@ -518,7 +517,7 @@ _complex_types: list[JAXType] = [
 # only meant for the `jnp.isdtype` and we want to be conservative and not allow
 # StringDType to be used in there.
 _string_types: list[JAXType] = []
-if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311:
+if hasattr(np.dtypes, 'StringDType'):
   _string_types: list[JAXType] = [np.dtypes.StringDType()]  # type: ignore
 
 _jax_dtype_set = {
@@ -55,7 +55,6 @@ from jax._src.sharding_impls import (AUTO, NamedSharding,
                                      SdyArraySharding, SdyArrayShardingList)
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import dialects, ir, passmanager
 from jax._src.lib.mlir.dialects import func as func_dialect, hlo
 from jax._src.lib.mlir import register_jax_dialects
@@ -479,20 +478,13 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
     loc = ctx.traceback_caches.location_cache.get(code_lasti, None)
     if loc is None:
       frame = source_info_util.raw_frame_to_frame(code, lasti)
-      if xla_extension_version >= 309:
-        file_loc = ir.Location.file(
-            get_canonical_source_file(frame.file_name, ctx.traceback_caches),
-            frame.start_line,
-            frame.start_column,
-            frame.end_line,
-            frame.end_column,
-        )
-      else:
-        file_loc = ir.Location.file(
-            get_canonical_source_file(frame.file_name, ctx.traceback_caches),
-            frame.start_line,
-            frame.start_column,
-        )
+      file_loc = ir.Location.file(
+          get_canonical_source_file(frame.file_name, ctx.traceback_caches),
+          frame.start_line,
+          frame.start_column,
+          frame.end_line,
+          frame.end_column,
+      )
       loc = ir.Location.name(frame.function_name, childLoc=file_loc)
       ctx.traceback_caches.location_cache[code_lasti] = loc
     frame_locs.append(loc)
@@ -49,7 +49,6 @@ from jax._src.lax import lax as lax_internal
 from jax._src.lax.lax import (PrecisionLike,_array_copy,
                               _sort_le_comparator, _sort_lt_comparator)
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.numpy.array_creation import (empty, empty_like, full,
                                            ones, ones_like, zeros, zeros_like)
 from jax._src.numpy import indexing
@@ -5357,11 +5356,6 @@ def _make_string_array(
     ndmin: int = 0,
     device: xc.Device | Sharding | None = None,
 ) -> Array:
-  if xla_extension_version < 311:
-    raise TypeError(
-        "String arrays are not supported in JAX before XLA extension version"
-        " 311."
-    )
   if not isinstance(object, np.ndarray):
     raise TypeError(
         "Currently, string arrays can only be made from NumPy"
@@ -25,7 +25,6 @@ import jax._src.core as jax_core
 from jax._src.interpreters import mlir
 from jax._src.lib import triton
 from jax._src.lib import gpu_triton as triton_kernel_call_lib
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.triton import lowering
@@ -98,7 +97,7 @@ def pallas_call_lowering(
   module_op.write_bytecode(buf)
 
   # TODO(b/394629193): Remove True once the bug is fixed.
-  if True or jaxlib_version < (0, 5, 1):
+  if True:
     # AOT Triton compilation is only available on jaxlib 0.5.1+.
     out_types = [
         ir.RankedTensorType.get(bm.array_shape_dtype.shape,
@@ -591,12 +591,7 @@ def _call_tf_lowering(
   call = func_dialect.CallOp(callee_result_types,
                              ir.FlatSymbolRefAttr.get(fn),
                              tuple(args_op) + captured_ops)
-  if result_shape.is_tuple() and xla_client.mlir_api_version < 58:
-    # In API version 58, the results are always flattened.
-    flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
-                    for i in range(len(result_shapes))]
-  else:
-    flat_results = call.results
+  flat_results = call.results
 
   if ordered:
     raise NotImplementedError(
@@ -16,7 +16,6 @@
 import jax
 from typing import Any, TYPE_CHECKING
 from jax._src.lib import xla_client as _xc
-from jax._src.lib import xla_extension_version
 from jax._src.util import use_cpp_class, use_cpp_method
 
 class TransferConnection:
@@ -42,7 +41,7 @@ class TransferConnection:
     return tree.unflatten(self._pull_flat(uuid, backend, xs_flat))
 
 
-if not TYPE_CHECKING and xla_extension_version >= 305:
+if not TYPE_CHECKING:
   TransferConnection = use_cpp_class(_xc._xla.TransferConnection)(TransferConnection)
 
 
@@ -67,8 +66,7 @@ class TransferServer:
     self._await_pull_flat(uuid, jax.tree.flatten(arrays)[0])
 
 
-if not TYPE_CHECKING and xla_extension_version >= 305:
+if not TYPE_CHECKING:
   TransferServer = use_cpp_class(_xc._xla.TransferServer)(TransferServer)
 
-if xla_extension_version >= 305:
-  start_transfer_server = _xc._xla.start_transfer_server
+start_transfer_server = _xc._xla.start_transfer_server
@@ -61,7 +61,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.compilation_cache import is_persistent_cache_enabled
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 import jax._src.util as jax_util
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.custom_batching
@@ -1371,8 +1370,6 @@ class JitTest(jtu.BufferDonationTestCase):
     def f(x):
       return jnp.sqrt(x**2) + 1.0
 
-    if xla_extension_version < 316:
-      self.skipTest("Requires XLA extension version >= 316")
     f_jit = jit(
         f,
         compiler_options={
@@ -1386,8 +1383,6 @@ class JitTest(jtu.BufferDonationTestCase):
     def f(x):
       return jnp.sqrt(x**2) + 1.0
 
-    if xla_extension_version < 316:
-      self.skipTest("Requires XLA extension version >= 316")
     f_jit = jit(
         f,
         compiler_options={
@@ -29,7 +29,6 @@ from jax._src import op_shardings
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import dialects, ir
 from jax._src.util import safe_zip
 from jax._src.mesh import AxisTypes
@@ -825,8 +824,6 @@ class JaxArrayTest(jtu.JaxTestCase):
   @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
   @jtu.run_on_devices("gpu")
   def test_pinned_host_npy_value_doesnt_cache(self, dtype):
-    if xla_extension_version < 314:
-      self.skipTest("Requires XLA extension version >= 314")
     # see https://github.com/jax-ml/jax/issues/26216
     d_tensor = jnp.array(0, dtype=dtype)
     d_sharding = d_tensor.sharding
@@ -25,7 +25,6 @@ from absl.testing import parameterized
 import jax
 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 from jax.experimental import colocated_python
 from jax.experimental.colocated_python import serialization
 from jax.extend.ifrt_programs import ifrt_programs
@@ -383,11 +382,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
       del colocated_python._testing_global_state
 
   def testStringProcessing(self):
-    if xla_extension_version < 315:
-      self.skipTest(
-          "String support for colocated Python requires xla_extension_version"
-          " >= 315"
-      )
     if np.lib.NumpyVersion(np.__version__) < "2.0.0":
       self.skipTest("StringDType requires NumPy 2.0.0 or later")
     cpu_devices = _colocated_cpu_devices(jax.local_devices())
@@ -431,11 +425,6 @@ class ColocatedPythonTest(jtu.JaxTestCase):
       )
 
   def testBinaryDataProcessing(self):
-    if xla_extension_version < 315:
-      self.skipTest(
-          "String support for colocated Python requires xla_extension_version"
-          " >= 315"
-      )
     if np.lib.NumpyVersion(np.__version__) < "2.0.0":
       self.skipTest("StringDType requires NumPy 2.0.0 or later")
     cpu_devices = _colocated_cpu_devices(jax.local_devices())
@@ -25,7 +25,6 @@ from jax._src import ad_checkpoint
 from jax._src import debugging
 from jax._src import dispatch
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 import jax.numpy as jnp
 import numpy as np
 
@@ -62,9 +61,6 @@ class DebugCallbackTest(jtu.JaxTestCase):
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   @jtu.run_on_devices("cpu")
   def test_async_deadlock(self):
-    if xla_extension_version < 306:
-      self.skipTest("deadlock expected")
-
     # See https://github.com/jax-ml/jax/issues/25861
     def print_it(i, maxiter):
       self.assertIsInstance(i, jax.Array)
@@ -32,7 +32,7 @@ from jax._src import dispatch
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir
 from jax._src.layout import DeviceLocalLayout
-from jax._src.lib import lapack, xla_extension_version
+from jax._src.lib import lapack
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lax import linalg as lax_linalg_internal
 from jax.experimental.shard_map import shard_map
@@ -333,8 +333,6 @@ def ffi_call_geqrf(x, _use_extend=False, **kwargs):
 class BatchPartitioningTest(jtu.JaxTestCase):
   def setUp(self):
     super().setUp()
-    if xla_extension_version < 313:
-      self.skipTest("Requires XLA extension version >= 313")
     # Register callbacks before checking the number of devices to make sure
     # that we're testing the registration path, even if we can't run the tests.
     for target_name in ["lapack_sgeqrf_ffi", "cusolver_geqrf_ffi",
@@ -23,7 +23,6 @@ from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
 from jax._src import config
 from jax._src.layout import Layout, DeviceLocalLayout as DLL
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 from jax._src.util import safe_zip
 from jax.experimental.compute_on import compute_on
 
@@ -33,12 +32,6 @@ jtu.request_cpu_devices(8)
 
 class LayoutTest(jtu.JaxTestCase):
 
-  # Remove this setUp once the released xla_extension_version is >= 308.
-  def setUp(self):
-    if xla_extension_version < 308 and not jtu.test_device_matches(['tpu', 'gpu']):
-      self.skipTest("Layouts do not work on CPU backend yet.")
-    super().setUp()
-
   def test_auto_layout(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     shape1 = (128, 128)
@@ -1707,8 +1707,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
     if pivoting:
       if not jtu.test_device_matches(["cpu", "gpu"]):
         self.skipTest("Pivoting is only supported on CPU and GPU.")
-      if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0):
-        self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0")
     rng = jtu.rand_default(self.rng())
     jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)
     sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting)
@@ -120,8 +120,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
   )
   @jtu.run_on_devices("gpu")
   def testPivotedQrFactorization(self, shape, dtype):
-    if jtu.jaxlib_version() <= (0, 5, 0):
-      self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     rng = jtu.rand_default(self.rng())
@@ -133,8 +131,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
     self._CompileAndCheck(lax_func, args_maker)
 
   def testPivotedQrFactorizationMagmaConfig(self):
-    if jtu.jaxlib_version() <= (0, 5, 0):
-      self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     rng = jtu.rand_default(self.rng())
@@ -26,7 +26,6 @@ from jax import lax
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.layout import DeviceLocalLayout as DLL, Layout
-from jax._src.lib import xla_extension_version
 from jax._src import config
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.numpy as jnp
@@ -1664,8 +1663,6 @@ class ComputeOffload(jtu.BufferDonationTestCase):
 class StreamAnnotationTest(jtu.JaxTestCase):
 
   def test_stream_annotation_inside_shmap(self):
-    if xla_extension_version < 313:
-      self.skipTest("Requires xla_extension_version >= 313")
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Stream annotation is only supported on GPU.")
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@@ -29,7 +29,6 @@ import jax
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir import passmanager
 from jax._src.lib.mlir.dialects import arith
@@ -2701,9 +2700,6 @@ class UtilsTest(TestCase):
 class SerializationTest(absltest.TestCase):
 
   def test_pass_is_registered(self):
-    if jaxlib_version < (0, 5, 1):
-      self.skipTest("Test requires jaxlib 0.5.1 or later")
-
     ctx = mlir.make_ir_context()
     ctx.allow_unregistered_dialects = True
     with ir.Location.unknown(ctx):
@@ -62,7 +62,6 @@ from jax._src.lib.mlir import dialects
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 from jax._src.util import curry, unzip2
 
 config.parse_flags_with_absl()
@@ -7147,9 +7146,6 @@ class PJitErrorTest(jtu.JaxTestCase):
     f(arr, 2., 3.)  # doesn't crash
 
   def test_named_sharding_of_none(self):
-    if xla_extension_version < 309:
-      raise unittest.SkipTest("NamedSharding does't reject None.")
-
     mesh = jtu.create_mesh((2,), ('x',))
     with self.assertRaisesRegex(TypeError, 'Unexpected None'):
       jax.NamedSharding(mesh, None)
@@ -37,8 +37,6 @@ class RaggedCollectiveTest(jtu.JaxTestCase):
     super().setUp()
     if jtu.test_device_matches(['cpu']):
       self.skipTest('ragged-all-to-all is not supported on CPU')
-    if jtu.jaxlib_version() < (0, 5, 1):
-      self.skipTest('ragged-all-to-all is not supported on jaxlib version < 0.5.1')
 
   @parameterized.named_parameters(
       dict(
@@ -18,7 +18,6 @@ import jax
 from jax import numpy as jnp
 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 import numpy as np
 
 config.parse_flags_with_absl()
@@ -29,12 +28,6 @@ class StringArrayTest(jtu.JaxTestCase):
 
   def setUp(self):
     super().setUp()
-    if xla_extension_version < 311:
-      self.skipTest(
-          "Skipping this test because the current XLA extension version:"
-          f" {xla_extension_version} is older than 309, the oldest version with"
-          " string array support."
-      )
     if not hasattr(np.dtypes, "StringDType"):
       self.skipTest(
           "Skipping this test because the numpy.dtype.StringDType is not"