Remove code that existed to support jaxlib < 0.4.32.

New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
This commit is contained in:
Peter Hawkins 2024-09-16 14:29:21 -07:00 committed by jax authors
parent df385b6ad3
commit 940860625e
9 changed files with 45 additions and 112 deletions

View File

@ -33,7 +33,6 @@ from jax._src import profiler
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np
@ -157,8 +156,7 @@ def get_compile_options(
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
if xla_extension_version >= 280:
build_options.use_shardy_partitioner = use_shardy_partitioner
build_options.use_shardy_partitioner = use_shardy_partitioner
if fdo_profile is not None:
build_options.fdo_profile = fdo_profile
if use_auto_spmd_partitioning:

View File

@ -61,7 +61,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@ -3022,12 +3021,8 @@ class MeshExecutable(stages.XlaExecutable):
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)
if xla_extension_version < 282:
def cc_shard_arg(x, sharding):
return shard_args([sharding], [None], [x])[0]
else:
def cc_shard_arg(x, sharding, layout): # type: ignore
return shard_args([sharding], [layout], [x])[0]
def cc_shard_arg(x, sharding, layout):
return shard_args([sharding], [layout], [x])[0]
def check_arg_avals_for_call(ref_avals, arg_avals,

View File

@ -514,11 +514,7 @@ def _cholesky_cpu_lowering(ctx, operand):
out_aval, = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
# TODO(b/344892332): Remove the check after the compatibility period.
if jaxlib_version < (0, 4, 31):
ctx_arg = ()
else:
ctx_arg = (ctx,)
ctx_arg = (ctx,)
result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand,
lower=True, a_shape_vals=op_shape_vals)
@ -556,7 +552,7 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector):
def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector):
# TODO(b/360781533): Remove guard after 3 week forward compatibility period.
if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32):
if ctx.is_forward_compat():
r_matrix_aval, _ = ctx.avals_in
try:
[platform] = ctx.module_context.platforms
@ -726,8 +722,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
ctx_args = (ctx,)
w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand,
input_shape_vals=op_shape_vals,
jobvl=compute_left_eigenvectors,
@ -937,8 +932,7 @@ def _eigh_cpu_gpu_lowering(
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
cpu_args = []
if platform == "cpu":
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
ctx_args = (ctx,)
cpu_args.extend(ctx_args)
v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand,
a_shape_vals=op_shape_vals, lower=lower)
@ -1511,9 +1505,9 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str,
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
m = operand_aval.shape[-2]
# TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum
# version and the forward compat flag after the 3 week compatibility window.
if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat():
# TODO(b/357034884): Remove version gate on the forward compat flag after the
# 3 week compatibility window.
if ctx.is_forward_compat():
if not is_constant_shape(operand_aval.shape[-2:]):
raise NotImplementedError(
"Shape polymorphism for native lowering for lu on CPU and GPU is "
@ -1757,9 +1751,8 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
(ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
(ctx,) if platform == "cpu" else ()
)
a_out, taus, *maybe_info_geqrf = geqrf_impl(
*ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals
@ -1881,9 +1874,8 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
f"on GPU is not implemented; b/261671778; {a_aval.shape}")
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
else:
# TODO(b/344892332): Remove the conditional after the compatibility period
ctx_args = (
(ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
(ctx,) if platform == "cpu" else ()
)
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
@ -2152,8 +2144,7 @@ def _svd_cpu_gpu_lowering(
compute_uv=compute_uv)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
ctx_args = (ctx,)
s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,

View File

@ -66,7 +66,6 @@ from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.lib import version as jaxlib_version
config.parse_flags_with_absl()
@ -190,14 +189,11 @@ class CompatTest(bctu.CompatTestBase):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 31)
self.run_one_test(func, data, rtol=rtol, atol=atol)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol)
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol)
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
@ -258,14 +254,11 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_eig_results)
@staticmethod
def eigh_input(shape, dtype):
@ -316,14 +309,11 @@ class CompatTest(bctu.CompatTestBase):
atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
# FFI Kernel test
with config.export_ignore_forward_compatibility(True):
data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
# FFI Kernel test
with config.export_ignore_forward_compatibility(True):
data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_eigh_results, operand))
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{variant}",
@ -385,8 +375,6 @@ class CompatTest(bctu.CompatTestBase):
def test_cuda_lu_lapack_getrf(self, dtype_name:str):
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
if jaxlib_version < (0, 4, 32):
self.skipTest("Not implemented in older versions of jaxlib")
dtype = dict(f32=np.float32, f64=np.float64,
c64=np.complex64, c128=np.complex128)[dtype_name]
shape = (3, 4)
@ -416,15 +404,12 @@ class CompatTest(bctu.CompatTestBase):
data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
self.run_one_test(func, data, rtol=rtol)
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
)
self.run_one_test(func, data, rtol=rtol)
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
)
self.run_one_test(func, data, rtol=rtol)
@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}_{batched}",
@ -502,14 +487,11 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
dtype=dtype))
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_lu_results, operand,
dtype=dtype))
def check_svd_results(self, input, res_run, res_exp,
@ -629,16 +611,13 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results,
input))
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
)
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results, input))
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(
cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
)
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=partial(self.check_svd_results, input))
@jtu.parameterized_filterable(
kwargs=[

View File

@ -15,7 +15,6 @@
import contextlib
import math
from functools import partial
import unittest
from absl.testing import absltest
import numpy as np
@ -511,8 +510,6 @@ class LayoutTest(jtu.JaxTestCase):
'Layout passed to jit does not match the layout on the respective arg'):
g(arr)
@unittest.skipIf(xla_extension_version < 282,
"Requires xla_extension_version >= 282")
def test_in_layouts_jit_jnp_input(self):
major_last_layout = DLL(major_to_minor=(1, 0))
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])

View File

@ -16,7 +16,6 @@
from functools import partial
import itertools
import unittest
import numpy as np
import scipy
@ -2194,9 +2193,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
symmetrize_output=[True, False],
)
@jtu.skip_on_devices("tpu")
@unittest.skipIf(
jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32"
)
def testSymmetricProduct(self, shape, dtype, symmetrize_output):
rng = jtu.rand_default(self.rng())
batch_size = 10

View File

@ -56,7 +56,6 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_extension
from jax._src.util import curry, unzip2
@ -4433,8 +4432,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
"Compiled object called with input sharding.*does not match"):
compiled(cpu_arr)
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_different_devices_wsc_abstract_mesh_cache_hit(self):
if jax.device_count() < 4:
self.skipTest('Requires >=4 devices')
@ -4463,8 +4460,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
@ -4484,8 +4479,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(out_eager, np_inp * 2)
self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_sds_abstract_mesh(self):
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P())
@ -4499,8 +4492,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s)
f.eval_shape(sds) # doesn't crash
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_vmap_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
@ -4517,8 +4508,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr)
self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x')))
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_wsc_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)

View File

@ -2843,11 +2843,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
((2, 3, 8, 4), "b1, b2, ..."),
((2, 3, 4, 5), "b1, b2, m, n"),
]
# TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version.
# jaxlib versions before 0.4.32 require a static shape for the non-batch
# dimensions because these are used for computing the "permuation_size"
# which is passed to lu_pivots_to_permutation.
if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n")
],
[
# The random primitive tests, with threefry (both partitionable and

View File

@ -45,7 +45,6 @@ from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
import jax.numpy as jnp
from jax._src.lib import xla_extension_version
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.shard_map import shard_map
@ -777,8 +776,6 @@ class ShardMapTest(jtu.JaxTestCase):
# is over an axis of size 2. This is a problem at the moment.
jax.make_jaxpr(mapped)(x, y).jaxpr
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_shard_map_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
@ -803,8 +800,6 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertArraysEqual(out2, np_inp)
self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_different_devices_shmap_abstract_mesh_cache_hit(self):
if jax.device_count() < 4:
self.skipTest('Requires >=4 devices')
@ -835,8 +830,6 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_shmap_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8)