mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
#sdy unskip JAX Shardy tests that are already passing
PiperOrigin-RevId: 718898708
This commit is contained in:
parent
4222c30cf0
commit
db8c8fc37c
@ -456,8 +456,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
|
||||
self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING))
|
||||
|
||||
def test_persistent_cache_miss_logging_with_explain(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
|
||||
with (config.explain_cache_misses(True),
|
||||
config.compilation_cache_dir("jax-cache")):
|
||||
|
||||
@ -502,8 +500,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
|
||||
|
||||
def test_persistent_cache_miss_logging_with_no_explain(self):
|
||||
# test that cache failure messages do not get logged in WARNING
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
|
||||
with (config.explain_cache_misses(False),
|
||||
config.compilation_cache_dir("jax-cache")):
|
||||
# omitting writing to cache because compilation is too fast
|
||||
|
@ -27,7 +27,6 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import state
|
||||
@ -1416,9 +1415,6 @@ class OpsTest(PallasBaseTest):
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Test for TPU is covered in tpu_pallas_test.py")
|
||||
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
|
||||
|
||||
# TODO: this test flakes on gpu
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("This test flakes on gpu")
|
||||
@ -2254,8 +2250,6 @@ class OpsInterpretTest(OpsTest):
|
||||
INTERPRET = True
|
||||
|
||||
def test_debug_print(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet")
|
||||
@functools.partial(
|
||||
self.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
|
||||
|
@ -398,8 +398,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
@jtu.run_on_devices('tpu')
|
||||
def testBufferDonationWithOutputShardingInferenceAndTokens(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest('b/355263220: Shardy does not support callbacks yet.')
|
||||
mesh = jtu.create_mesh((2,), 'x')
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
|
||||
@ -4312,7 +4310,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
def test_empty_io_callback_under_shard_map(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("Shardy errors out on empty callbacks.")
|
||||
self.skipTest("TODO(b/384938613): Failing under shardy.")
|
||||
mesh = jtu.create_mesh((4,), 'i')
|
||||
|
||||
def empty_callback(x):
|
||||
@ -4330,7 +4328,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
def test_empty_io_callback_under_shard_map_reshard_to_singledev(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("Shardy errors out on empty callbacks.")
|
||||
self.skipTest("TODO(b/384938613): Failing under shardy.")
|
||||
mesh = jtu.create_mesh((4,), 'i')
|
||||
|
||||
def empty_callback(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user