Always dispatch CPU executables synchronously when they include callbacks.

As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations.

There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:

1. Executing a program that includes any host callbacks, or
2. when running within the body of a callback.

It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented.

This PR includes just the first change, and the second will be implemented in a follow-up.

PiperOrigin-RevId: 720777713
This commit is contained in:
Dan Foreman-Mackey 2025-01-28 18:22:55 -08:00 committed by jax authors
parent bf22b53cf4
commit 83457c115a

View File

@ -25,6 +25,7 @@ from jax._src import ad_checkpoint
from jax._src import debugging
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import jax.numpy as jnp
import numpy as np
@ -58,6 +59,47 @@ class DebugCallbackTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, "callable"):
jax.debug.callback("this is not debug.print!")
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.run_on_devices("cpu")
def test_async_deadlock(self):
if xla_extension_version < 306:
self.skipTest("deadlock expected")
# See https://github.com/jax-ml/jax/issues/25861
def print_it(i, maxiter):
self.assertIsInstance(i, jax.Array)
self.assertIsInstance(maxiter, jax.Array)
return i == maxiter # Using JAX here causes deadlock with async dispatch
def run(pos):
maxiter = 1000
def cond(v):
return v[0] < maxiter
def step(v):
i, pos = v
jax.debug.callback(print_it, i + 1, maxiter)
return i + 1, pos + 1
val = jnp.array(0), pos
val = jax.lax.while_loop(cond, step, val)
return val[1]
n_samples = 30
inputs = 10 * jax.random.normal(
jax.random.key(42), shape=(n_samples, 128, 128)
)
def mean(forest):
norm = 1.0 / len(forest)
add = lambda a, b: a + b
m = norm * functools.reduce(add, forest)
return m
post_mean = mean(tuple(run(x) for x in inputs))
jax.block_until_ready(post_mean) # This shouldn't deadlock.
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintTest(jtu.JaxTestCase):