diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 83d5aa49d..3fcc0ab47 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -456,8 +456,6 @@ class CompilationCacheTest(CompilationCacheTestCase): self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING)) def test_persistent_cache_miss_logging_with_explain(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") with (config.explain_cache_misses(True), config.compilation_cache_dir("jax-cache")): @@ -502,8 +500,6 @@ class CompilationCacheTest(CompilationCacheTestCase): def test_persistent_cache_miss_logging_with_no_explain(self): # test that cache failure messages do not get logged in WARNING - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") with (config.explain_cache_misses(False), config.compilation_cache_dir("jax-cache")): # omitting writing to cache because compilation is too fast diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e0730b758..357dad6cd 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -27,7 +27,6 @@ from absl.testing import parameterized import jax from jax import lax from jax import random -from jax._src import config from jax._src import dtypes from jax._src import linear_util as lu from jax._src import state @@ -1416,9 +1415,6 @@ class OpsTest(PallasBaseTest): if jtu.test_device_matches(["tpu"]): self.skipTest("Test for TPU is covered in tpu_pallas_test.py") - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") - # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): self.skipTest("This test flakes on gpu") @@ -2254,8 +2250,6 @@ class OpsInterpretTest(OpsTest): INTERPRET = True def test_debug_print(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8e0efe21c..6d7ce9cd6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -398,8 +398,6 @@ class PJitTest(jtu.BufferDonationTestCase): @jtu.run_on_devices('tpu') def testBufferDonationWithOutputShardingInferenceAndTokens(self): - if config.use_shardy_partitioner.value: - self.skipTest('b/355263220: Shardy does not support callbacks yet.') mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) @@ -4312,7 +4310,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_empty_io_callback_under_shard_map(self): if config.use_shardy_partitioner.value: - self.skipTest("Shardy errors out on empty callbacks.") + self.skipTest("TODO(b/384938613): Failing under shardy.") mesh = jtu.create_mesh((4,), 'i') def empty_callback(x): @@ -4330,7 +4328,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_empty_io_callback_under_shard_map_reshard_to_singledev(self): if config.use_shardy_partitioner.value: - self.skipTest("Shardy errors out on empty callbacks.") + self.skipTest("TODO(b/384938613): Failing under shardy.") mesh = jtu.create_mesh((4,), 'i') def empty_callback(x):