diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index 019d68cff..531161eff 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -25,6 +25,7 @@ from jax._src import ad_checkpoint
 from jax._src import debugging
 from jax._src import dispatch
 from jax._src import test_util as jtu
+from jax._src.lib import xla_extension_version
 import jax.numpy as jnp
 import numpy as np
 
@@ -58,6 +59,57 @@ class DebugCallbackTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, "callable"):
       jax.debug.callback("this is not debug.print!")
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
+  @jtu.run_on_devices("cpu")
+  def test_async_deadlock(self):
+    # Regression test: many debug callbacks inside while_loops must not hang.
+    # Older XLA extension builds are skipped; the skip message indicates the
+    # deadlock is still expected there.
+    if xla_extension_version < 306:
+      self.skipTest("deadlock expected")
+
+    # See https://github.com/jax-ml/jax/issues/25861
+    # Callback invoked from inside the loop body; both arguments are expected
+    # to arrive as jax.Array values.
+    def print_it(i, maxiter):
+      self.assertIsInstance(i, jax.Array)
+      self.assertIsInstance(maxiter, jax.Array)
+      return i == maxiter  # Using JAX here causes deadlock with async dispatch
+
+    # Counts from 0 up to maxiter in a while_loop, firing the callback and
+    # adding 1 to `pos` on every iteration; returns the updated `pos`.
+    def run(pos):
+      maxiter = 1000
+
+      def cond(v):
+        return v[0] < maxiter
+
+      def step(v):
+        i, pos = v
+        jax.debug.callback(print_it, i + 1, maxiter)
+        return i + 1, pos + 1
+
+      val = jnp.array(0), pos  # loop carry: (iteration counter, payload)
+      val = jax.lax.while_loop(cond, step, val)
+      return val[1]
+
+    # 30 random 128x128 inputs, each driven through its own loop below.
+    n_samples = 30
+    inputs = 10 * jax.random.normal(
+        jax.random.key(42), shape=(n_samples, 128, 128)
+    )
+
+    # Average of the items in `forest`: reduce-sum scaled by 1 / len(forest).
+    def mean(forest):
+      norm = 1.0 / len(forest)
+      add = lambda a, b: a + b
+      m = norm * functools.reduce(add, forest)
+      return m
+
+    # All loops are issued first; only block_until_ready explicitly waits.
+    post_mean = mean(tuple(run(x) for x in inputs))
+    jax.block_until_ready(post_mean)  # This shouldn't deadlock.
+
 
 @jtu.thread_unsafe_test_class()  # printing isn't thread-safe
 class DebugPrintTest(jtu.JaxTestCase):