[JAX Shardy] Unskip stream annotation test when shardy is enabled, since the underlying issue is now resolved.

PiperOrigin-RevId: 738802372
This commit is contained in:
Tom Natan 2025-03-20 08:01:08 -07:00 committed by jax authors
parent 7fa7db7a9f
commit c098b363fb

View File

@ -1664,8 +1664,6 @@ class StreamAnnotationTest(jtu.JaxTestCase):
def test_stream_annotation_inside_shmap(self):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Stream annotation is only supported on GPU.")
if config.use_shardy_partitioner.value:
self.skipTest("Doesn't work with shardy")
mesh = jtu.create_mesh((2,), ('x',))
s = NamedSharding(mesh, P('x'))