From efab6945ca285743df5265838307958857f17df5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 17 Jan 2025 14:15:36 -0500 Subject: [PATCH] Remove code that supported jaxlib < 0.5. The new xla_extension_version is 303 and the new mlir_api_version is 57. --- jax/_src/compiler.py | 9 ++---- jax/_src/lax/linalg.py | 13 ++------- jax/_src/pallas/mosaic/lowering.py | 5 ---- jax/_src/xla_bridge.py | 10 +------ tests/cache_key_test.py | 4 +-- tests/compilation_cache_test.py | 4 --- tests/export_back_compat_test.py | 44 +++++++++++------------------- tests/lax_numpy_test.py | 3 -- tests/linalg_test.py | 15 +--------- tests/magma_linalg_test.py | 6 ---- tests/pallas/tpu_ops_test.py | 2 -- tests/pallas/tpu_pallas_test.py | 2 -- tests/pgle_test.py | 10 +------ tests/shard_map_test.py | 3 -- 14 files changed, 25 insertions(+), 105 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 21bc80384..9250e0b61 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -32,8 +32,6 @@ from jax._src import path as pathlib from jax._src import profiler from jax._src import traceback_util from jax._src.interpreters import mlir -from jax._src.lib import version as jaxlib_version -from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir import numpy as np @@ -199,10 +197,7 @@ def get_compile_options( logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.") if env_options_overrides is None: env_options_overrides = {} - if xla_extension_version > 302: - env_options_overrides['xla_gpu_enable_command_buffer'] = '' - else: - env_options_overrides['xla_gpu_graph_min_graph_size'] = '100000' + env_options_overrides['xla_gpu_enable_command_buffer'] = '' if env_options_overrides is not None: # Some overrides are passed directly on build_options. @@ -258,7 +253,7 @@ def get_compile_options( debug_options.xla_detailed_logging = detailed_logging # If persistent cache is enabled, also enable additional XLA caching features. - if compilation_cache.is_persistent_cache_enabled() and jaxlib_version > (0, 4, 35): + if compilation_cache.is_persistent_cache_enabled(): # compilation_cache_dir can't be None here, but the type checker is a bit # strict. path = pathlib.Path(config.compilation_cache_dir.value or "") diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 440ee424f..85f334816 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -47,7 +47,6 @@ from jax._src.lax.lax import ( from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -1364,10 +1363,8 @@ def _triangular_solve_cpu_lower( if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types: alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)) b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period. 
- ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else () return lapack.trsm_hlo( - *ctx_args, a_aval.dtype, alpha, + ctx, a_aval.dtype, alpha, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal, b_shape_vals=b_shape_vals) else: @@ -2540,9 +2537,6 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b): def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs): - if jaxlib_version <= (0, 4, 38): - rule = mlir.lower_fun(_tridiagonal_solve_jax, multiple_results=False) - return rule(ctx, dl, d, du, b, **kwargs) b_aval = ctx.avals_in[-1] batch_dims = b_aval.shape[:-2] target_name = lapack.prepare_lapack_call("gtsv_ffi", b_aval.dtype) @@ -2755,13 +2749,12 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals, a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else () - gees_result = lapack.gees_hlo(*ctx_args, operand_aval.dtype, operand, + gees_result = lapack.gees_hlo(ctx, operand_aval.dtype, operand, jobvs=compute_schur_vectors, sort=sort_eig_vals, select=select_callable, a_shape_vals=a_shape_vals) - if jaxlib_version >= (0, 4, 37) and not ctx.is_forward_compat(): + if not ctx.is_forward_compat(): schur_form, schur_vectors, _eig_vals, _selected_eig_vals, info = gees_result else: # Number of return values depends on value of sort_eig_vals. diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 58ed2d22f..20256331c 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -42,7 +42,6 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -1550,10 +1549,6 @@ def _masked_swap_lowering_rule( if mask is not None: raise NotImplementedError("masked swap with strided store") tpu.StridedStoreOp(val, ref, starts, strides) - elif jaxlib_version <= (0, 4, 35): - if mask is not None: - raise NotImplementedError("masked swap with vector store") - vector.StoreOp(val, ref, starts) else: tpu.VectorStoreOp(val, ref, starts, [], mask=mask) return result diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index bbe663175..4d32f787e 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -277,21 +277,13 @@ def make_cpu_client( f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.") num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None - if xla_client._version < 303 and num_devices is not None: - xla_flags = os.getenv("XLA_FLAGS") or "" - os.environ["XLA_FLAGS"] = ( - f"{xla_flags} --xla_force_host_platform_device_count={num_devices}" - ) - num_devices = None - # TODO(phawkins): pass num_devices directly when version 303 is the minimum. 
- kwargs = {} if num_devices is None else {"num_devices": num_devices} return xla_client.make_cpu_client( asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, distributed_client=distributed.global_state.client, node_id=distributed.global_state.process_id, num_nodes=distributed.global_state.num_processes, collectives=collectives, - **kwargs, + num_devices=num_devices, ) diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 74f76c75b..2faa4dbaf 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -31,7 +31,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.mesh import Mesh from jax._src.partition_spec import PartitionSpec as P @@ -69,8 +68,7 @@ class CacheKeyTest(jtu.JaxTestCase): debug_options.xla_dump_hlo_as_long_text = True debug_options.xla_dump_disable_metadata = True debug_options.xla_dump_hlo_pipeline_re = "xyzzy" - if jaxlib_version > (0, 4, 35): - debug_options.xla_gpu_experimental_autotune_cache_mode = 2 + debug_options.xla_gpu_experimental_autotune_cache_mode = 2 hash2 = self.get_hashed_value( cache_key._hash_serialized_compile_options, compile_options ) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index ef245bc8d..3952843bd 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -531,8 +531,6 @@ class CompilationCacheTest(CompilationCacheTestCase): executable.fingerprint, deserialized_executable.fingerprint) def test_persistent_cache_enable_xla_caches(self): - if jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("Test requires AutotuneCacheMode bindings") s = os.sep with config.compilation_cache_dir("jax-cache"): with config.persistent_cache_enable_xla_caches("none"): @@ -603,8 +601,6 @@ class CompilationCacheDisabledTest(CompilationCacheTestCase): self.assertEqual(count_after_second_use, count_after_first_use) def test_persistent_cache_enable_xla_caches_disabled(self): - if jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("Test requires AutotuneCacheMode bindings") with config.enable_compilation_cache(False): compile_options = compiler.get_compile_options( num_replicas=1, num_partitions=1 diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index b453056e7..346490f67 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -71,7 +71,6 @@ from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions -from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -650,14 +649,11 @@ class CompatTest(bctu.CompatTestBase): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_schur_results) - # TODO(b/344892332): Remove the check after the compatibility period. 
- has_xla_ffi_support = jaxlib_version >= (0, 4, 37) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_schur_results) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_schur_results) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -746,14 +742,11 @@ class CompatTest(bctu.CompatTestBase): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_triangular_solve_results) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 37) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_triangular_solve_results) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_triangular_solve_results) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -808,23 +801,18 @@ class CompatTest(bctu.CompatTestBase): cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name] ) self.run_one_test(func, data, rtol=rtol, atol=atol) - # TODO(b/344892332): Remove the check after the compatibility period. 
- has_xla_ffi_support = jaxlib_version >= (0, 4, 37) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @jax.default_matmul_precision("float32") def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name): - if jtu.jaxlib_version() <= (0, 4, 38): - self.skipTest("Test requires a newer jaxlib version") if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index b1328016b..62b0fc994 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -50,7 +50,6 @@ from jax._src import core from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.lib import version as jaxlib_version from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace config.parse_flags_with_absl() @@ -1566,8 +1565,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.") - if rank == 2 and jaxlib_version <= (0, 4, 35) and jtu.test_device_matches(["gpu"]): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } if jtu.test_device_matches(["tpu"]): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 8327e8da4..3027b36ff 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -268,8 +268,6 @@ class NumpyLinalgTest(jtu.JaxTestCase): @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): - if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] args_maker = lambda: [rng(shape, dtype)] @@ -312,8 +310,6 @@ class NumpyLinalgTest(jtu.JaxTestCase): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" - if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, @@ -331,8 +327,6 @@ class NumpyLinalgTest(jtu.JaxTestCase): # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. 
- if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -346,8 +340,6 @@ class NumpyLinalgTest(jtu.JaxTestCase): ) @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): - if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -358,8 +350,6 @@ class NumpyLinalgTest(jtu.JaxTestCase): @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 - if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -369,8 +359,6 @@ class NumpyLinalgTest(jtu.JaxTestCase): ) @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): - if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape args = rng(shape, dtype) @@ -1697,8 +1685,7 @@ class ScipyLinalgTest(jtu.JaxTestCase): ) def testScipyQrModes(self, shape, dtype, mode, pivoting): is_not_cpu_test_device = not jtu.test_device_matches(["cpu"]) - is_not_valid_jaxlib_version = jtu.jaxlib_version() <= (0, 4, 38) - if pivoting and (is_not_cpu_test_device or is_not_valid_jaxlib_version): + if pivoting and is_not_cpu_test_device: self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38") rng = jtu.rand_default(self.rng()) jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting) diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py index bf9c0fb6b..27ab58aae 100644 --- a/tests/magma_linalg_test.py +++ b/tests/magma_linalg_test.py @@ -42,8 +42,6 @@ class MagmaLinalgTest(jtu.JaxTestCase): @jtu.run_on_devices("gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): - if jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for @@ -93,8 +91,6 @@ class MagmaLinalgTest(jtu.JaxTestCase): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" - if jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for @@ -110,8 +106,6 @@ class MagmaLinalgTest(jtu.JaxTestCase): self.assertTrue(np.all(np.isnan(result))) def testEigMagmaConfig(self): - if jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") rng = jtu.rand_default(self.rng()) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 7e1da537f..be1118f36 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -272,8 +272,6 @@ 
class OpsTest(PallasBaseTest): @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int8]) def test_cast_vector_to_mask(self, dtype): - if jtu.jaxlib_version() <= (0, 4, 39): - self.skipTest("Test requires non-32-bit selection support") shape = (128, 128) bitwidth = pallas_utils.dtype_bitwidth(dtype) if ( diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 746485d68..45d6cf2b5 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1724,8 +1724,6 @@ class PallasCallTest(PallasBaseTest): np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1)) def test_masked_store(self): - if jtu.jaxlib_version() <= (0, 4, 35): - self.skipTest("Test requires masked store support") shape = (16, 256) mask_shape = (10, 130) mask_start = (4, 5) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 46add86e2..6f6fb96aa 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -37,7 +37,6 @@ from jax.experimental.serialize_executable import ( import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec import numpy as np -from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member jax.config.parse_flags_with_absl() @@ -94,10 +93,7 @@ class PgleTest(jtu.JaxTestCase): 'xla_gpu_enable_latency_hiding_scheduler': 'True', } # TODO(b/37664749): Remove this flag once the bug is fixed. - if xla_extension_version > 302: - compiler_options['xla_gpu_enable_command_buffer'] = '' - else: - compiler_options['xla_gpu_graph_min_graph_size'] = '100000' + compiler_options['xla_gpu_enable_command_buffer'] = '' @partial( jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), @@ -138,8 +134,6 @@ class PgleTest(jtu.JaxTestCase): 'xla_gpu_experimental_dump_fdo_profiles': 'True', } # TODO(b/376647494): Remove this flag once the bug is fixed. - if xla_extension_version <= 302: - compile_options['xla_gpu_graph_min_graph_size'] = '100000' @partial( jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), @@ -224,8 +218,6 @@ class PgleTest(jtu.JaxTestCase): 'xla_gpu_experimental_dump_fdo_profiles': 'True', } # TODO(b/376647494): Remove this flag once the bug is fixed. - if xla_extension_version <= 302: - compiler_options['xla_gpu_graph_min_graph_size'] = '100000' @partial( jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index b23c17022..53c637959 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -48,7 +48,6 @@ import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.shard_map import shard_map -from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -2170,8 +2169,6 @@ class ShardMapTest(jtu.JaxTestCase): self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1)) def test_partial_auto_ppermute(self): - if xla_extension_version < 302: - self.skipTest('minimum xla extension version 302') if config.use_shardy_partitioner.value: self.skipTest('Shardy does not support full-to-shard.')
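-- 
Reviewer note, not part of the patch: every deletion above removes the same
version-gating idiom, sketched here once for context. The imports and the
303 / (0, 5, 0) thresholds are taken from the hunks and the commit message
above; the helper function itself is hypothetical and only illustrates why
the guards are now dead code.

    # Illustrative sketch only; not code introduced by this patch.
    # jaxlib publishes its version as a tuple and its XLA binding level as an
    # int, so minimum-version gates are plain comparisons against literals.
    from jax._src.lib import version as jaxlib_version
    from jax._src.lib import xla_extension_version

    def uses_legacy_jaxlib() -> bool:
      # With jaxlib 0.5 as the floor (xla_extension_version 303), both
      # conditions are always False, so call sites can drop the fallback
      # branch entirely, which is what the hunks above do.
      return jaxlib_version < (0, 5, 0) or xla_extension_version < 303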