From 9bc04393b2a7ccdce458e2b9c4e47eeafb53cef4 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 23 Oct 2023 12:21:52 -0700 Subject: [PATCH] Disable flaky python callback test. PiperOrigin-RevId: 575893965 --- tests/jaxpr_effects_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 890a63d16..161665c15 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -605,6 +605,8 @@ class EffectOrderingTest(jtu.JaxTestCase): jax.effects_barrier() self.assertListEqual(log, [2., 3.]) + # TODO(b/307211483): Investigate failure + @jtu.skip_on_devices("tpu") def test_ordered_effect_remains_ordered_across_multiple_devices(self): if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") @@ -632,8 +634,8 @@ class EffectOrderingTest(jtu.JaxTestCase): f(jnp.ones((500, 500))) g(3.) jax.effects_barrier() - x_, y_ = float(jnp.log(1.25e8)), 3. - expected_log = [x_, y_, x_, y_, x_, y_] + f_, g_ = float(jnp.log(1.25e8)), 3. + expected_log = [f_, g_, f_, g_, f_, g_] self.assertListEqual(log, expected_log) def test_different_threads_get_different_tokens(self):