mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Disable flaky python callback test.
PiperOrigin-RevId: 575893965
This commit is contained in:
parent
9b1a656c1e
commit
9bc04393b2
@ -605,6 +605,8 @@ class EffectOrderingTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
self.assertListEqual(log, [2., 3.])
|
||||
|
||||
# TODO(b/307211483): Investigate failure
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
@ -632,8 +634,8 @@ class EffectOrderingTest(jtu.JaxTestCase):
|
||||
f(jnp.ones((500, 500)))
|
||||
g(3.)
|
||||
jax.effects_barrier()
|
||||
x_, y_ = float(jnp.log(1.25e8)), 3.
|
||||
expected_log = [x_, y_, x_, y_, x_, y_]
|
||||
f_, g_ = float(jnp.log(1.25e8)), 3.
|
||||
expected_log = [f_, g_, f_, g_, f_, g_]
|
||||
self.assertListEqual(log, expected_log)
|
||||
|
||||
def test_different_threads_get_different_tokens(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user