From 83457c115ae7a115579b82a11931bc48d2b4bc37 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Tue, 28 Jan 2025 18:22:55 -0800
Subject: [PATCH] Always dispatch CPU executables synchronously when they include callbacks.

As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra-op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations.

There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:

1. Executing a program that includes any host callbacks, or
2. Running within the body of a callback.

It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented. This PR includes just the first change, and the second will be implemented in a follow-up.

PiperOrigin-RevId: 720777713
---
 tests/debugging_primitives_test.py | 42 ++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index 019d68cff..531161eff 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -25,6 +25,7 @@ from jax._src import ad_checkpoint
 from jax._src import debugging
 from jax._src import dispatch
 from jax._src import test_util as jtu
+from jax._src.lib import xla_extension_version
 import jax.numpy as jnp
 import numpy as np
 
@@ -58,6 +59,47 @@ class DebugCallbackTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, "callable"):
       jax.debug.callback("this is not debug.print!")
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
+  @jtu.run_on_devices("cpu")
+  def test_async_deadlock(self):
+    if xla_extension_version < 306:  # older versions still hit the deadlock
+      self.skipTest("deadlock expected")
+
+    # Regression test; see https://github.com/jax-ml/jax/issues/25861
+    def print_it(i, maxiter):
+      self.assertIsInstance(i, jax.Array)
+      self.assertIsInstance(maxiter, jax.Array)
+      return i == maxiter  # Using JAX here causes deadlock with async dispatch
+
+    def run(pos):
+      maxiter = 1000  # loop bound; the callback fires once per iteration
+
+      def cond(v):
+        return v[0] < maxiter  # v is the (counter, pos) loop carry
+
+      def step(v):
+        i, pos = v
+        jax.debug.callback(print_it, i + 1, maxiter)  # host callback in loop
+        return i + 1, pos + 1
+
+      val = jnp.array(0), pos
+      val = jax.lax.while_loop(cond, step, val)
+      return val[1]  # only the updated pos is returned
+
+    n_samples = 30  # one run() call per leading-axis slice of inputs
+    inputs = 10 * jax.random.normal(
+        jax.random.key(42), shape=(n_samples, 128, 128)
+    )
+
+    def mean(forest):
+      norm = 1.0 / len(forest)
+      add = lambda a, b: a + b
+      m = norm * functools.reduce(add, forest)  # sum of entries scaled by 1/len
+      return m
+
+    post_mean = mean(tuple(run(x) for x in inputs))
+    jax.block_until_ready(post_mean)  # This shouldn't deadlock.
+
 
 @jtu.thread_unsafe_test_class()  # printing isn't thread-safe
 class DebugPrintTest(jtu.JaxTestCase):