Disable flaky python callback test.

PiperOrigin-RevId: 575893965
This commit is contained in:
George Necula 2023-10-23 12:21:52 -07:00 committed by jax authors
parent 9b1a656c1e
commit 9bc04393b2

View File

@ -605,6 +605,8 @@ class EffectOrderingTest(jtu.JaxTestCase):
jax.effects_barrier()
self.assertListEqual(log, [2., 3.])
# TODO(b/307211483): Investigate failure
@jtu.skip_on_devices("tpu")
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
@ -632,8 +634,8 @@ class EffectOrderingTest(jtu.JaxTestCase):
f(jnp.ones((500, 500)))
g(3.)
jax.effects_barrier()
x_, y_ = float(jnp.log(1.25e8)), 3.
expected_log = [x_, y_, x_, y_, x_, y_]
f_, g_ = float(jnp.log(1.25e8)), 3.
expected_log = [f_, g_, f_, g_, f_, g_]
self.assertListEqual(log, expected_log)
def test_different_threads_get_different_tokens(self):