mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Always dispatch CPU executables synchronously when they include callbacks.
As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations. There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either: 1. Executing a program that includes any host callbacks, or 2. when running within the body of a callback. It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented. This PR includes just the first change, and the second will be implemented in a follow-up. PiperOrigin-RevId: 720777713
This commit is contained in:
parent
bf22b53cf4
commit
83457c115a
@ -25,6 +25,7 @@ from jax._src import ad_checkpoint
|
||||
from jax._src import debugging
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -58,6 +59,47 @@ class DebugCallbackTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, "callable"):
|
||||
jax.debug.callback("this is not debug.print!")
|
||||
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
@jtu.run_on_devices("cpu")
|
||||
def test_async_deadlock(self):
|
||||
if xla_extension_version < 306:
|
||||
self.skipTest("deadlock expected")
|
||||
|
||||
# See https://github.com/jax-ml/jax/issues/25861
|
||||
def print_it(i, maxiter):
|
||||
self.assertIsInstance(i, jax.Array)
|
||||
self.assertIsInstance(maxiter, jax.Array)
|
||||
return i == maxiter # Using JAX here causes deadlock with async dispatch
|
||||
|
||||
def run(pos):
|
||||
maxiter = 1000
|
||||
|
||||
def cond(v):
|
||||
return v[0] < maxiter
|
||||
|
||||
def step(v):
|
||||
i, pos = v
|
||||
jax.debug.callback(print_it, i + 1, maxiter)
|
||||
return i + 1, pos + 1
|
||||
|
||||
val = jnp.array(0), pos
|
||||
val = jax.lax.while_loop(cond, step, val)
|
||||
return val[1]
|
||||
|
||||
n_samples = 30
|
||||
inputs = 10 * jax.random.normal(
|
||||
jax.random.key(42), shape=(n_samples, 128, 128)
|
||||
)
|
||||
|
||||
def mean(forest):
|
||||
norm = 1.0 / len(forest)
|
||||
add = lambda a, b: a + b
|
||||
m = norm * functools.reduce(add, forest)
|
||||
return m
|
||||
|
||||
post_mean = mean(tuple(run(x) for x in inputs))
|
||||
jax.block_until_ready(post_mean) # This shouldn't deadlock.
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
|
||||
class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user