Remove code that supported jaxlib < 0.5.

The new xla_extension_version is 303 and the new mlir_api_version is 57.
Peter Hawkins 2025-01-17 14:15:36 -05:00
parent 7d81547f91
commit efab6945ca
14 changed files with 25 additions and 105 deletions
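
Every hunk below removes the same kind of guard: a check against an old jaxlib release or xla_extension_version, with a fallback branch for the pre-0.5 behavior, collapses into the unconditional new code path. A minimal sketch of the pattern, reusing the command-buffer override that appears in the first and PGLE-related hunks (the helper functions and their names are illustrative, not code from the repository):

# Before this commit: branch on the XLA extension version bundled with jaxlib.
from jax._src.lib import xla_extension_version  # this import is deleted below

def overrides_before(overrides: dict[str, str]) -> dict[str, str]:
  # Hypothetical helper illustrating the removed pattern.
  if xla_extension_version > 302:
    overrides['xla_gpu_enable_command_buffer'] = ''
  else:
    overrides['xla_gpu_graph_min_graph_size'] = '100000'
  return overrides

# After this commit: jaxlib >= 0.5 ships extension version 303, so the new
# flag is always available and the version check disappears.
def overrides_after(overrides: dict[str, str]) -> dict[str, str]:
  overrides['xla_gpu_enable_command_buffer'] = ''
  return overrides

The test files follow the same rule: skip guards keyed on jtu.jaxlib_version() are dropped because no supported jaxlib can fail them anymore.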

View File

@@ -32,8 +32,6 @@ from jax._src import path as pathlib
 from jax._src import profiler
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
-from jax._src.lib import version as jaxlib_version
-from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 import numpy as np
@@ -199,10 +197,7 @@ def get_compile_options(
     logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
     if env_options_overrides is None:
       env_options_overrides = {}
-    if xla_extension_version > 302:
-      env_options_overrides['xla_gpu_enable_command_buffer'] = ''
-    else:
-      env_options_overrides['xla_gpu_graph_min_graph_size'] = '100000'
+    env_options_overrides['xla_gpu_enable_command_buffer'] = ''
 
   if env_options_overrides is not None:
     # Some overrides are passed directly on build_options.
@@ -258,7 +253,7 @@ def get_compile_options(
   debug_options.xla_detailed_logging = detailed_logging
 
   # If persistent cache is enabled, also enable additional XLA caching features.
-  if compilation_cache.is_persistent_cache_enabled() and jaxlib_version > (0, 4, 35):
+  if compilation_cache.is_persistent_cache_enabled():
     # compilation_cache_dir can't be None here, but the type checker is a bit
     # strict.
     path = pathlib.Path(config.compilation_cache_dir.value or "")

View File

@@ -47,7 +47,6 @@ from jax._src.lax.lax import (
 from jax._src.lib import gpu_solver
 from jax._src.lib import gpu_sparse
 from jax._src.lib import lapack
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
@@ -1364,10 +1363,8 @@ def _triangular_solve_cpu_lower(
   if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
     alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
     b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
-    # TODO(b/344892332): Remove the conditional after the compatibility period.
-    ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
     return lapack.trsm_hlo(
-        *ctx_args, a_aval.dtype, alpha,
+        ctx, a_aval.dtype, alpha,
         a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
         b_shape_vals=b_shape_vals)
   else:
@@ -2540,9 +2537,6 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b):
 
 
 def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs):
-  if jaxlib_version <= (0, 4, 38):
-    rule = mlir.lower_fun(_tridiagonal_solve_jax, multiple_results=False)
-    return rule(ctx, dl, d, du, b, **kwargs)
   b_aval = ctx.avals_in[-1]
   batch_dims = b_aval.shape[:-2]
   target_name = lapack.prepare_lapack_call("gtsv_ffi", b_aval.dtype)
@@ -2755,13 +2749,12 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
   a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-  # TODO(b/344892332): Remove the conditional after the compatibility period.
-  ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
-  gees_result = lapack.gees_hlo(*ctx_args, operand_aval.dtype, operand,
+  gees_result = lapack.gees_hlo(ctx, operand_aval.dtype, operand,
                                 jobvs=compute_schur_vectors,
                                 sort=sort_eig_vals,
                                 select=select_callable,
                                 a_shape_vals=a_shape_vals)
-  if jaxlib_version >= (0, 4, 37) and not ctx.is_forward_compat():
+  if not ctx.is_forward_compat():
     schur_form, schur_vectors, _eig_vals, _selected_eig_vals, info = gees_result
   else:
     # Number of return values depends on value of sort_eig_vals.

View File

@@ -42,7 +42,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax as lax_internal
 from jax._src.lax.control_flow import for_loop
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import func
@@ -1550,10 +1549,6 @@ def _masked_swap_lowering_rule(
       if mask is not None:
         raise NotImplementedError("masked swap with strided store")
       tpu.StridedStoreOp(val, ref, starts, strides)
-    elif jaxlib_version <= (0, 4, 35):
-      if mask is not None:
-        raise NotImplementedError("masked swap with vector store")
-      vector.StoreOp(val, ref, starts)
     else:
       tpu.VectorStoreOp(val, ref, starts, [], mask=mask)
     return result

View File

@@ -277,21 +277,13 @@ def make_cpu_client(
         f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
 
   num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
-  if xla_client._version < 303 and num_devices is not None:
-    xla_flags = os.getenv("XLA_FLAGS") or ""
-    os.environ["XLA_FLAGS"] = (
-        f"{xla_flags} --xla_force_host_platform_device_count={num_devices}"
-    )
-    num_devices = None
-  # TODO(phawkins): pass num_devices directly when version 303 is the minimum.
-  kwargs = {} if num_devices is None else {"num_devices": num_devices}
 
   return xla_client.make_cpu_client(
       asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
      distributed_client=distributed.global_state.client,
      node_id=distributed.global_state.process_id,
      num_nodes=distributed.global_state.num_processes,
      collectives=collectives,
-      **kwargs,
+      num_devices=num_devices,
   )
 

View File

@@ -31,7 +31,6 @@ from jax._src import config
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.mesh import Mesh
 from jax._src.partition_spec import PartitionSpec as P
@@ -69,8 +68,7 @@ class CacheKeyTest(jtu.JaxTestCase):
     debug_options.xla_dump_hlo_as_long_text = True
     debug_options.xla_dump_disable_metadata = True
     debug_options.xla_dump_hlo_pipeline_re = "xyzzy"
-    if jaxlib_version > (0, 4, 35):
-      debug_options.xla_gpu_experimental_autotune_cache_mode = 2
+    debug_options.xla_gpu_experimental_autotune_cache_mode = 2
     hash2 = self.get_hashed_value(
         cache_key._hash_serialized_compile_options, compile_options
     )

View File

@@ -531,8 +531,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
         executable.fingerprint, deserialized_executable.fingerprint)
 
   def test_persistent_cache_enable_xla_caches(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("Test requires AutotuneCacheMode bindings")
     s = os.sep
     with config.compilation_cache_dir("jax-cache"):
       with config.persistent_cache_enable_xla_caches("none"):
@@ -603,8 +601,6 @@ class CompilationCacheDisabledTest(CompilationCacheTestCase):
     self.assertEqual(count_after_second_use, count_after_first_use)
 
   def test_persistent_cache_enable_xla_caches_disabled(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("Test requires AutotuneCacheMode bindings")
     with config.enable_compilation_cache(False):
       compile_options = compiler.get_compile_options(
           num_replicas=1, num_partitions=1

View File

@@ -71,7 +71,6 @@ from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
-from jax._src.lib import version as jaxlib_version
 
 
 config.parse_flags_with_absl()
@@ -650,14 +649,11 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=check_schur_results)
 
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=check_schur_results)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=check_schur_results)
 
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
@@ -746,14 +742,11 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=check_triangular_solve_results)
 
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=check_triangular_solve_results)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=check_triangular_solve_results)
 
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
@@ -808,23 +801,18 @@ class CompatTest(bctu.CompatTestBase):
         cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name]
     )
     self.run_one_test(func, data, rtol=rtol, atol=atol)
 
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(
-            cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name]
-        )
-        self.run_one_test(func, data, rtol=rtol, atol=atol)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(
+          cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name]
+      )
+      self.run_one_test(func, data, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
       for dtype_name in ("f32", "f64", "c64", "c128"))
   @jax.default_matmul_precision("float32")
   def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name):
-    if jtu.jaxlib_version() <= (0, 4, 38):
-      self.skipTest("Test requires a newer jaxlib version")
     if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
       self.skipTest("Test disabled for x32 mode")

View File

@@ -50,7 +50,6 @@ from jax._src import core
 from jax._src import dtypes
 from jax._src import test_util as jtu
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import version as jaxlib_version
 from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace
 
 config.parse_flags_with_absl()
@@ -1566,8 +1565,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
       self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
     elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]):
       self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.")
-    if rank == 2 and jaxlib_version <= (0, 4, 35) and jtu.test_device_matches(["gpu"]):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
     if jtu.test_device_matches(["tpu"]):

View File

@@ -268,8 +268,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   @jtu.run_on_devices("cpu", "gpu")
   def testEig(self, shape, dtype, compute_left_eigenvectors,
               compute_right_eigenvectors):
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     n = shape[-1]
     args_maker = lambda: [rng(shape, dtype)]
@@ -312,8 +310,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
                               compute_right_eigenvectors):
     """Verifies that `eig` fails gracefully if given non-finite inputs."""
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     a = jnp.full(shape, jnp.nan, dtype)
     results = lax.linalg.eig(
         a, compute_left_eigenvectors=compute_left_eigenvectors,
@@ -331,8 +327,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     # haven't checked, that might be because of perturbations causing the
     # ordering of eigenvalues to change, which will trip up check_grads. So we
     # just test on small-ish matrices.
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
@@ -346,8 +340,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   @jtu.run_on_devices("cpu", "gpu")
   def testEigvals(self, shape, dtype):
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
@@ -358,8 +350,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   @jtu.run_on_devices("cpu", "gpu")
   def testEigvalsInf(self):
     # https://github.com/jax-ml/jax/issues/2661
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     x = jnp.array([[jnp.inf]])
     self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))
 
@@ -369,8 +359,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   @jtu.run_on_devices("cpu", "gpu")
   def testEigBatching(self, shape, dtype):
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     shape = (10,) + shape
     args = rng(shape, dtype)
@@ -1697,8 +1685,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
   )
   def testScipyQrModes(self, shape, dtype, mode, pivoting):
     is_not_cpu_test_device = not jtu.test_device_matches(["cpu"])
-    is_not_valid_jaxlib_version = jtu.jaxlib_version() <= (0, 4, 38)
-    if pivoting and (is_not_cpu_test_device or is_not_valid_jaxlib_version):
+    if pivoting and is_not_cpu_test_device:
       self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38")
     rng = jtu.rand_default(self.rng())
     jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)

View File

@@ -42,8 +42,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
   @jtu.run_on_devices("gpu")
   def testEig(self, shape, dtype, compute_left_eigenvectors,
               compute_right_eigenvectors):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for
@@ -93,8 +91,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
   def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
                               compute_right_eigenvectors):
     """Verifies that `eig` fails gracefully if given non-finite inputs."""
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for
@@ -110,8 +106,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
     self.assertTrue(np.all(np.isnan(result)))
 
   def testEigMagmaConfig(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     rng = jtu.rand_default(self.rng())

View File

@@ -272,8 +272,6 @@ class OpsTest(PallasBaseTest):
 
   @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int8])
   def test_cast_vector_to_mask(self, dtype):
-    if jtu.jaxlib_version() <= (0, 4, 39):
-      self.skipTest("Test requires non-32-bit selection support")
     shape = (128, 128)
     bitwidth = pallas_utils.dtype_bitwidth(dtype)
     if (

View File

@@ -1724,8 +1724,6 @@ class PallasCallTest(PallasBaseTest):
     np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1))
 
   def test_masked_store(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("Test requires masked store support")
     shape = (16, 256)
     mask_shape = (10, 130)
     mask_start = (4, 5)

View File

@@ -37,7 +37,6 @@ from jax.experimental.serialize_executable import (
 import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec
 import numpy as np
-from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
 
 
 jax.config.parse_flags_with_absl()
@@ -94,10 +93,7 @@ class PgleTest(jtu.JaxTestCase):
         'xla_gpu_enable_latency_hiding_scheduler': 'True',
     }
     # TODO(b/376647494): Remove this flag once the bug is fixed.
-    if xla_extension_version > 302:
-      compiler_options['xla_gpu_enable_command_buffer'] = ''
-    else:
-      compiler_options['xla_gpu_graph_min_graph_size'] = '100000'
+    compiler_options['xla_gpu_enable_command_buffer'] = ''
     @partial(
         jax.jit,
         in_shardings=NamedSharding(mesh, PartitionSpec('x')),
@@ -138,8 +134,6 @@ class PgleTest(jtu.JaxTestCase):
         'xla_gpu_experimental_dump_fdo_profiles': 'True',
     }
     # TODO(b/376647494): Remove this flag once the bug is fixed.
-    if xla_extension_version <= 302:
-      compile_options['xla_gpu_graph_min_graph_size'] = '100000'
     @partial(
         jax.jit,
         in_shardings=NamedSharding(mesh, PartitionSpec('x')),
@@ -224,8 +218,6 @@ class PgleTest(jtu.JaxTestCase):
         'xla_gpu_experimental_dump_fdo_profiles': 'True',
     }
     # TODO(b/376647494): Remove this flag once the bug is fixed.
-    if xla_extension_version <= 302:
-      compiler_options['xla_gpu_graph_min_graph_size'] = '100000'
     @partial(
         jax.jit,
         in_shardings=NamedSharding(mesh, PartitionSpec('x')),

View File

@@ -48,7 +48,6 @@ import jax.numpy as jnp
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax.experimental.shard_map import shard_map
-from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
 
 
 config.parse_flags_with_absl()
 jtu.request_cpu_devices(8)
@@ -2170,8 +2169,6 @@ class ShardMapTest(jtu.JaxTestCase):
     self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))
 
   def test_partial_auto_ppermute(self):
-    if xla_extension_version < 302:
-      self.skipTest('minimum xla extension version 302')
     if config.use_shardy_partitioner.value:
       self.skipTest('Shardy does not support full-to-shard.')