#sdy unskip JAX Shardy tests that are already passing

PiperOrigin-RevId: 718898708
This commit is contained in:
Bart Chrzaszcz 2025-01-23 09:26:04 -08:00 committed by jax authors
parent 4222c30cf0
commit db8c8fc37c
3 changed files with 2 additions and 14 deletions

View File

@ -456,8 +456,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))
def test_persistent_cache_miss_logging_with_explain(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
with (config.explain_cache_misses(True),
config.compilation_cache_dir("jax-cache")):
@ -502,8 +500,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
def test_persistent_cache_miss_logging_with_no_explain(self):
# test that cache failure messages do not get logged in WARNING
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
with (config.explain_cache_misses(False),
config.compilation_cache_dir("jax-cache")):
# omitting writing to cache because compilation is too fast

View File

@ -27,7 +27,6 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import random
from jax._src import config
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
@ -1416,9 +1415,6 @@ class OpsTest(PallasBaseTest):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Test for TPU is covered in tpu_pallas_test.py")
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
# TODO: this test flakes on gpu
if jtu.test_device_matches(["gpu"]):
self.skipTest("This test flakes on gpu")
@ -2254,8 +2250,6 @@ class OpsInterpretTest(OpsTest):
INTERPRET = True
def test_debug_print(self):
if config.use_shardy_partitioner.value:
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),

View File

@ -398,8 +398,6 @@ class PJitTest(jtu.BufferDonationTestCase):
@jtu.run_on_devices('tpu')
def testBufferDonationWithOutputShardingInferenceAndTokens(self):
if config.use_shardy_partitioner.value:
self.skipTest('b/355263220: Shardy does not support callbacks yet.')
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P('x'))
@ -4312,7 +4310,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_empty_io_callback_under_shard_map(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy errors out on empty callbacks.")
self.skipTest("TODO(b/384938613): Failing under shardy.")
mesh = jtu.create_mesh((4,), 'i')
def empty_callback(x):
@ -4330,7 +4328,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_empty_io_callback_under_shard_map_reshard_to_singledev(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy errors out on empty callbacks.")
self.skipTest("TODO(b/384938613): Failing under shardy.")
mesh = jtu.create_mesh((4,), 'i')
def empty_callback(x):