commit 940860625e
parent df385b6ad3

Remove code that existed to support jaxlib < 0.4.32.

New minimum versions:
  * jaxlib 0.4.32
  * xla_extension_version 283
  * mlir_api_version 57

PiperOrigin-RevId: 675291231
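The hunks below are reproduced in unified-diff form. For context: jax gates features that depend on a newer jaxlib behind integer and tuple version checks, and deletes those gates once the minimum supported jaxlib is raised, which is exactly what this commit does. A minimal sketch of the two gating idioms being deleted, assuming the jax._src.lib symbols as they existed at the time of this commit:

# Sketch of the version-gate idioms this commit deletes (illustrative).
from jax._src.lib import version as jaxlib_version   # release tuple, e.g. (0, 4, 32)
from jax._src.lib import xla_extension_version       # C-extension int, e.g. 283

# Gate on the C-extension version:
if xla_extension_version >= 280:
  pass  # take the code path that needs the newer runtime
else:
  pass  # fall back for older installed jaxlibs

# Gate on the jaxlib release tuple:
if jaxlib_version < (0, 4, 32):
  pass  # legacy kernel signature
else:
  pass  # current kernel signature

With the new minimums (jaxlib 0.4.32, xla_extension_version 283), every guard of this shape below is statically true or false, so each hunk collapses to one branch.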
@@ -33,7 +33,6 @@ from jax._src import profiler
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 import numpy as np
 
@@ -157,8 +156,7 @@ def get_compile_options(
   build_options = compile_options.executable_build_options
   build_options.use_spmd_partitioning = use_spmd_partitioning
   build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
-  if xla_extension_version >= 280:
-    build_options.use_shardy_partitioner = use_shardy_partitioner
+  build_options.use_shardy_partitioner = use_shardy_partitioner
   if fdo_profile is not None:
     build_options.fdo_profile = fdo_profile
   if use_auto_spmd_partitioning:
@@ -61,7 +61,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -3022,12 +3021,8 @@ class MeshExecutable(stages.XlaExecutable):
         self.unsafe_call.name, None, aot_cache_miss, [], [], [],
         tree_util.dispatch_registry, cc_shard_arg)
 
-if xla_extension_version < 282:
-  def cc_shard_arg(x, sharding):
-    return shard_args([sharding], [None], [x])[0]
-else:
-  def cc_shard_arg(x, sharding, layout):  # type: ignore
-    return shard_args([sharding], [layout], [x])[0]
+def cc_shard_arg(x, sharding, layout):
+  return shard_args([sharding], [layout], [x])[0]
 
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
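In this hunk the xla_extension_version < 282 branch dies because the minimum is now 283: the C++ fast-path dispatch can be relied on to pass a layout to the shard-arg callback, so only the three-argument form survives. An illustrative sketch of how the callback composes with the batched helper (not the exact jax internals):

# Illustrative: shard_args operates on parallel lists, while the dispatch
# path hands the callback one argument at a time, hence the singleton lists.
def cc_shard_arg(x, sharding, layout):
  return shard_args([sharding], [layout], [x])[0]

# Hypothetical caller, one callback invocation per argument:
def shard_all(xs, shardings, layouts):
  return [cc_shard_arg(x, s, l) for x, s, l in zip(xs, shardings, layouts)]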
@@ -514,11 +514,7 @@ def _cholesky_cpu_lowering(ctx, operand):
   out_aval, = ctx.avals_out
   batch_dims = operand_aval.shape[:-2]
   op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-  # TODO(b/344892332): Remove the check after the compatibility period.
-  if jaxlib_version < (0, 4, 31):
-    ctx_arg = ()
-  else:
-    ctx_arg = (ctx,)
+  ctx_arg = (ctx,)
   result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand,
                                   lower=True, a_shape_vals=op_shape_vals)
 
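The ctx_arg = (ctx,) line survives because tuple unpacking lets one call site serve kernel bindings that may or may not take a leading ctx parameter; with jaxlib < 0.4.31 gone, ctx is always passed. A sketch of the idiom with hypothetical kernels:

# Sketch of the conditional-prefix-argument idiom (hypothetical kernels).
def new_kernel(ctx, dtype, operand):  # FFI-style kernel: takes ctx first
  return ...

def old_kernel(dtype, operand):       # legacy kernel: no ctx
  return ...

def lower(ctx, dtype, operand, *, use_new: bool):
  kernel = new_kernel if use_new else old_kernel
  ctx_args = (ctx,) if use_new else ()  # an empty tuple unpacks to nothing
  return kernel(*ctx_args, dtype, operand)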
@@ -556,7 +552,7 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector):
 
 def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector):
   # TODO(b/360781533): Remove guard after 3 week forward compatibility period.
-  if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32):
+  if ctx.is_forward_compat():
     r_matrix_aval, _ = ctx.avals_in
     try:
       [platform] = ctx.module_context.platforms
@@ -726,8 +722,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
   out_aval = ctx.avals_out[0]
   batch_dims = operand_aval.shape[:-2]
   op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-  # TODO(b/344892332): Remove the conditional after the compatibility period.
-  ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+  ctx_args = (ctx,)
   w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand,
                                     input_shape_vals=op_shape_vals,
                                     jobvl=compute_left_eigenvectors,
@@ -937,8 +932,7 @@ def _eigh_cpu_gpu_lowering(
   op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
   cpu_args = []
   if platform == "cpu":
-    # TODO(b/344892332): Remove the conditional after the compatibility period.
-    ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+    ctx_args = (ctx,)
     cpu_args.extend(ctx_args)
   v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand,
                           a_shape_vals=op_shape_vals, lower=lower)
@@ -1511,9 +1505,9 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str,
   info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
   m = operand_aval.shape[-2]
 
-  # TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum
-  # version and the forward compat flag after the 3 week compatibility window.
-  if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat():
+  # TODO(b/357034884): Remove version gate on the forward compat flag after the
+  # 3 week compatibility window.
+  if ctx.is_forward_compat():
     if not is_constant_shape(operand_aval.shape[-2:]):
       raise NotImplementedError(
         "Shape polymorphism for native lowering for lu on CPU and GPU is "
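Note the distinction this hunk preserves: the jaxlib version gate goes away, but ctx.is_forward_compat() stays. The forward-compatibility window is about consumers of exported artifacts, not the locally installed jaxlib, so the legacy lowering must still be emitted during the window. A rough sketch of that shape (helper names hypothetical):

# Illustrative shape of a forward-compatibility gate in a lowering rule.
def _lowering_rule(ctx, operand):
  if ctx.is_forward_compat():
    # Emit the older custom call so artifacts exported by this jax version
    # still run on consumers that only know the legacy kernel.
    return _legacy_lowering(ctx, operand)  # hypothetical helper
  return _ffi_lowering(ctx, operand)       # hypothetical helper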
@@ -1757,9 +1751,8 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *,
     a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
   else:
     a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
-    # TODO(b/344892332): Remove the conditional after the compatibility period
     ctx_args = (
-        (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+        (ctx,) if platform == "cpu" else ()
     )
     a_out, taus, *maybe_info_geqrf = geqrf_impl(
         *ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals
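Here the conditional survives in reduced form: only the CPU (LAPACK) kernel bindings grew a leading ctx parameter, so the platform check remains while the jaxlib check goes. A condensed sketch (impl name as in the hunk, logic illustrative):

# Sketch: ctx is only threaded into the CPU kernel (logic illustrative).
def lower_geqrf(ctx, a_aval, a, *, platform: str, geqrf_impl):
  ctx_args = (ctx,) if platform == "cpu" else ()  # GPU kernels take no ctx
  return geqrf_impl(*ctx_args, a_aval.dtype, a)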
@@ -1881,9 +1874,8 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
           f"on GPU is not implemented; b/261671778; {a_aval.shape}")
     a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
   else:
-    # TODO(b/344892332): Remove the conditional after the compatibility period
     ctx_args = (
-        (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+        (ctx,) if platform == "cpu" else ()
     )
     a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
     tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
@@ -2152,8 +2144,7 @@ def _svd_cpu_gpu_lowering(
                                compute_uv=compute_uv)
   else:
     a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-    # TODO(b/344892332): Remove the conditional after the compatibility period.
-    ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
+    ctx_args = (ctx,)
     s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand,
                                 full_matrices=full_matrices,
                                 compute_uv=compute_uv,
@@ -66,7 +66,6 @@ from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
-from jax._src.lib import version as jaxlib_version
 
 config.parse_flags_with_absl()
 
@@ -190,14 +189,11 @@ class CompatTest(bctu.CompatTestBase):
     atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
 
     data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name])
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 31)
     self.run_one_test(func, data, rtol=rtol, atol=atol)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
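The test-side cleanup is the mirror image: has_xla_ffi_support is always true now, so only the export_ignore_forward_compatibility block remains. That context manager temporarily forces export to target the new FFI kernels even inside the forward-compatibility window; a minimal sketch of the idiom, assuming jax._src.config as used above:

# Illustrative: jax config context managers temporarily flip a config value.
from jax._src import config

def run_against_ffi_kernels(fn):
  with config.export_ignore_forward_compatibility(True):
    return fn()  # inside the block, lowering targets the new FFI kernels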
@@ -258,14 +254,11 @@ class CompatTest(bctu.CompatTestBase):
 
     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=check_eig_results)
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=check_eig_results)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=check_eig_results)
 
   @staticmethod
   def eigh_input(shape, dtype):
@@ -316,14 +309,11 @@ class CompatTest(bctu.CompatTestBase):
     atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name]
     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=partial(self.check_eigh_results, operand))
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-    if has_xla_ffi_support:
-      # FFI Kernel test
-      with config.export_ignore_forward_compatibility(True):
-        data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=partial(self.check_eigh_results, operand))
+    # FFI Kernel test
+    with config.export_ignore_forward_compatibility(True):
+      data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=partial(self.check_eigh_results, operand))
 
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}_{variant}",
@@ -385,8 +375,6 @@ class CompatTest(bctu.CompatTestBase):
   def test_cuda_lu_lapack_getrf(self, dtype_name:str):
     if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
       self.skipTest("Test disabled for x32 mode")
-    if jaxlib_version < (0, 4, 32):
-      self.skipTest("Not implemented in older versions of jaxlib")
     dtype = dict(f32=np.float32, f64=np.float64,
                  c64=np.complex64, c128=np.complex128)[dtype_name]
     shape = (3, 4)
@@ -416,15 +404,12 @@ class CompatTest(bctu.CompatTestBase):
     data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name])
     rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name]
     self.run_one_test(func, data, rtol=rtol)
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(
-            cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
-        )
-        self.run_one_test(func, data, rtol=rtol)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(
+          cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name]
+      )
+      self.run_one_test(func, data, rtol=rtol)
 
   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}_{batched}",
@@ -502,14 +487,11 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=partial(self.check_lu_results, operand,
                                             dtype=dtype))
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=partial(self.check_lu_results, operand,
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=partial(self.check_lu_results, operand,
                                               dtype=dtype))
 
   def check_svd_results(self, input, res_run, res_exp,
@@ -629,16 +611,13 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=partial(self.check_svd_results,
                                             input))
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 32)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(
-            cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
-        )
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=partial(self.check_svd_results, input))
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(
+          cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name]
+      )
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=partial(self.check_svd_results, input))
 
   @jtu.parameterized_filterable(
     kwargs=[
@@ -15,7 +15,6 @@
 import contextlib
 import math
 from functools import partial
-import unittest
 from absl.testing import absltest
 import numpy as np
 
@@ -511,8 +510,6 @@ class LayoutTest(jtu.JaxTestCase):
         'Layout passed to jit does not match the layout on the respective arg'):
       g(arr)
 
-  @unittest.skipIf(xla_extension_version < 282,
-                   "Requires xla_extension_version >= 282")
   def test_in_layouts_jit_jnp_input(self):
     major_last_layout = DLL(major_to_minor=(1, 0))
     sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
@@ -16,7 +16,6 @@
 
 from functools import partial
 import itertools
-import unittest
 
 import numpy as np
 import scipy
@@ -2194,9 +2193,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
       symmetrize_output=[True, False],
   )
   @jtu.skip_on_devices("tpu")
-  @unittest.skipIf(
-      jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32"
-  )
   def testSymmetricProduct(self, shape, dtype, symmetrize_output):
     rng = jtu.rand_default(self.rng())
     batch_size = 10
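Test decorators follow the same gating pattern: once the minimum jaxlib satisfies the predicate, the skipIf can never fire and is deleted, which also lets the now-unused import unittest go. A sketch of the deleted pattern (test body illustrative):

# Illustrative: a version-gated test that becomes unconditional after the bump.
import unittest
import jax._src.lib

class ExampleTest(unittest.TestCase):
  @unittest.skipIf(
      jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32")
  def test_feature(self):
    self.assertTrue(True)  # stand-in body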
@@ -56,7 +56,6 @@ from jax._src.interpreters import pxla
 from jax._src.lib.mlir import dialects
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib import xla_extension
 from jax._src.util import curry, unzip2
 
@@ -4433,8 +4432,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
         "Compiled object called with input sharding.*does not match"):
       compiled(cpu_arr)
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_different_devices_wsc_abstract_mesh_cache_hit(self):
     if jax.device_count() < 4:
       self.skipTest('Requires >=4 devices')
@@ -4463,8 +4460,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(lowering_count[0], 1)
     self.assertEqual(compilation_count[0], 2)  # 2 misses since devices differ.
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_wsc_abstract_mesh(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
@@ -4484,8 +4479,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertArraysEqual(out_eager, np_inp * 2)
     self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_wsc_sds_abstract_mesh(self):
     mesh = jtu.create_mesh((2,), 'x')
     s = NamedSharding(mesh, P())
@@ -4499,8 +4492,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
     sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s)
     f.eval_shape(sds)  # doesn't crash
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_wsc_vmap_abstract_mesh(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     s = NamedSharding(mesh, P('x', 'y'))
@@ -4517,8 +4508,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
     out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr)
     self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x')))
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_wsc_abstract_mesh_errors(self):
     mesh = jtu.create_mesh((2,), ('x',))
     np_inp = np.arange(8)
@@ -2843,11 +2843,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
           ((2, 3, 8, 4), "b1, b2, ..."),
           ((2, 3, 4, 5), "b1, b2, m, n"),
       ]
-      # TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version.
-      # jaxlib versions before 0.4.32 require a static shape for the non-batch
-      # dimensions because these are used for computing the "permuation_size"
-      # which is passed to lu_pivots_to_permutation.
-      if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n")
   ],
   [
     # The random primitive tests, with threefry (both partitionable and
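Dropping the harness filter means the symbolic "m, n" cases now always run: older jaxlib needed concrete trailing dimensions to compute the permutation size passed to lu_pivots_to_permutation. An illustrative reconstruction of what the filter did inside the harness comprehension:

# Illustrative reconstruction of the dropped filter (not the exact harness code).
shapes_and_polys = [
    ((2, 3, 8, 4), "b1, b2, ..."),
    ((2, 3, 4, 5), "b1, b2, m, n"),
]
# Old behavior: keep a case only if jaxlib could handle symbolic trailing dims,
# i.e. jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n").
# New behavior: every case is kept.
harnesses = [(shape, poly) for shape, poly in shapes_and_polys]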
@@ -45,7 +45,6 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src import linear_util as lu
 from jax._src import tree_util
 import jax.numpy as jnp
-from jax._src.lib import xla_extension_version
 
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax.experimental.shard_map import shard_map
@@ -777,8 +776,6 @@ class ShardMapTest(jtu.JaxTestCase):
     # is over an axis of size 2. This is a problem at the moment.
     jax.make_jaxpr(mapped)(x, y).jaxpr
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_shard_map_abstract_mesh(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
@@ -803,8 +800,6 @@ class ShardMapTest(jtu.JaxTestCase):
     self.assertArraysEqual(out2, np_inp)
     self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_different_devices_shmap_abstract_mesh_cache_hit(self):
     if jax.device_count() < 4:
       self.skipTest('Requires >=4 devices')
@@ -835,8 +830,6 @@ class ShardMapTest(jtu.JaxTestCase):
     self.assertEqual(lowering_count[0], 1)
     self.assertEqual(compilation_count[0], 2)  # 2 misses since devices differ.
 
-  @unittest.skipIf(xla_extension_version < 281,
-                   'Requires xla_extension_version >= 281')
   def test_shmap_abstract_mesh_errors(self):
     mesh = jtu.create_mesh((2,), ('x',))
     np_inp = np.arange(8)
|
Loading…
x
Reference in New Issue
Block a user