Make JAX test suite pass (at least most of the time) with multiple threads enabled.

Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
This commit is contained in:
Peter Hawkins 2025-01-10 06:58:01 -08:00 committed by jax authors
parent 86643a1b3e
commit c61b2f6b81
19 changed files with 88 additions and 38 deletions

View File

@ -1013,16 +1013,27 @@ if hasattr(util, 'Mutex'):
_test_rwlock = util.Mutex()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
_test_rwlock.reader_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.reader_unlock()
if getattr(test.__class__, "thread_hostile", False):
_test_rwlock.writer_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.writer_unlock()
else:
_test_rwlock.reader_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.reader_unlock()
@contextmanager
def thread_hostile_test():
"Decorator for tests that are not thread-safe."
def thread_unsafe_test():
"""Decorator for tests that are not thread-safe.
Note: this decorator (naturally) only applies to what it wraps, not to, say,
code in separate setUp() or tearDown() methods.
"""
if TEST_NUM_THREADS.value <= 0:
yield
return
@ -1048,9 +1059,19 @@ else:
@contextmanager
def thread_hostile_test():
def thread_unsafe_test():
yield # No reader-writer lock, so we get no parallelism.
def thread_unsafe_test_class():
  """Returns a class decorator that marks a TestCase class as thread-hostile.

  The returned decorator sets ``thread_hostile = True`` on the decorated class
  and returns that same class object. The test runner in this module looks the
  flag up on each test's class and runs flagged tests under the writer lock
  rather than the reader lock.

  Raises:
    TypeError: (from the returned decorator) if the decorated object is not a
      subclass of ``unittest.TestCase``.
  """
  def f(klass):
    # Validate explicitly rather than with `assert`: assertions are stripped
    # under `python -O`, and a bare issubclass() call on a non-class (e.g. the
    # decorator applied to a test method by mistake) fails with an opaque
    # "arg 1 must be a class" message.
    if not (isinstance(klass, type) and issubclass(klass, unittest.TestCase)):
      raise TypeError(
          "thread_unsafe_test_class() must decorate a unittest.TestCase "
          f"subclass, got {klass!r}")
    klass.thread_hostile = True
    return klass
  return f
class ThreadSafeTestResult:
"""
Wraps a TestResult to make it thread safe.
@ -1074,8 +1095,9 @@ class ThreadSafeTestResult:
def stopTest(self, test: unittest.TestCase):
stop_time = time.time()
with self.lock:
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
# override how it gets the time.
# If test_result is an ABSL _TextAndXMLTestResult we override how it gets
# the time. This affects the timing that shows up in the XML output
# consumed by CI.
time_getter = getattr(self.test_result, "time_getter", None)
try:
self.test_result.time_getter = lambda: self.start_time

View File

@ -632,7 +632,7 @@ class JitTest(jtu.BufferDonationTestCase):
python_should_be_executing = False
jit(f)(3)
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # GC effects aren't predictable with threads
def test_jit_cache_clear(self):
@jit
def f(x, y):
@ -1605,6 +1605,7 @@ class APITest(jtu.JaxTestCase):
assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))
@jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace.
def test_grad_of_jit(self):
side = []
@ -1618,6 +1619,7 @@ class APITest(jtu.JaxTestCase):
assert grad(f)(2.0) == 4.0
assert len(side) == 1
@jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace.
def test_jit_of_grad(self):
side = []
@ -2589,7 +2591,7 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
self.assertEqual(pytree[3], 4)
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads
def test_devicearray_weakref_friendly(self):
x = device_put(1.)
y = weakref.ref(x)
@ -2738,7 +2740,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(count(), 1)
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # jit cache misses aren't thread safe
def test_jit_infer_params_cache(self):
def f(x):
return x
@ -3329,7 +3331,7 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"):
jax.grad(lambda x: x)(x)
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # logging isn't thread-safe
def test_jit_compilation_time_logging(self):
@api.jit
def f(x):
@ -3418,7 +3420,7 @@ class APITest(jtu.JaxTestCase):
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertEqual(z2, 1)
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # monkey-patching mlir.jaxpr_subcomp isn't thread-safe
def test_nested_jit_hoisting(self):
@api.jit
def f(x, y):
@ -3456,7 +3458,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
def test_primitive_compilation_cache(self):
with jtu.count_primitive_compiles() as count:
lax.add(1, 2)
@ -4016,7 +4018,7 @@ class APITest(jtu.JaxTestCase):
a2 = jnp.array(((x, x), [x, x]))
self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe
def test_eval_shape_weak_type(self):
# https://github.com/jax-ml/jax/issues/23302
arr = jax.numpy.array(1)
@ -4145,7 +4147,7 @@ class APITest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
self.assertIn('Precision.HIGH', str(jaxpr))
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # Updating global configs is not thread-safe.
def test_dot_precision_forces_retrace(self):
num_traces = 0
@ -4318,7 +4320,7 @@ class APITest(jtu.JaxTestCase):
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count(), 0)
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # Updating global configs is not thread-safe.
def test_rank_promotion_forces_retrace(self):
num_traces = 0
@ -4459,7 +4461,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})")
self.assertEqual(jfoo.__module__, "jax")
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace
def test_inner_jit_function_retracing(self):
# https://github.com/jax-ml/jax/issues/7155
inner_count = outer_count = 0
@ -4507,6 +4509,7 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"):
jax.device_put(jnp.arange(8), 'cpu')
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_clear_cache(self):
@jax.jit
def add(x):
@ -4525,6 +4528,7 @@ class APITest(jtu.JaxTestCase):
tracing_add_count += 1
self.assertEqual(tracing_add_count, 2)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations(self):
@jax.jit
def f(x, y):
@ -4584,6 +4588,7 @@ class APITest(jtu.JaxTestCase):
msg = cm.output[0]
self.assertIn("tracing context doesn't match", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_new_function_in_loop(self):
@jax.jit
def f(x, y):
@ -4605,6 +4610,7 @@ class APITest(jtu.JaxTestCase):
_, msg = cm.output
self.assertIn('another function defined on the same line', msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_unpacks_transforms(self):
# Tests that the explain_tracing_cache_miss() function does not throw an
# error when unpacking `transforms` with a length greater than 3.
@ -4701,7 +4707,7 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "ndim of its first argument"):
jax.sharding.Mesh(jax.devices(), ("x", "y"))
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # weakref gc doesn't seem predictable
def test_jit_boundmethod_reference_cycle(self):
class A:
def __init__(self):
@ -4840,7 +4846,7 @@ class RematTest(jtu.JaxTestCase):
('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # monkey patches sin_p and cos_p
def test_remat_basic(self, remat):
@remat
def g(x):
@ -5178,7 +5184,7 @@ class RematTest(jtu.JaxTestCase):
('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # monkey patches sin_p
def test_remat_no_redundant_flops(self, remat):
# see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584
@ -6422,7 +6428,7 @@ class RematTest(jtu.JaxTestCase):
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # logging isn't thread-safe
def test_remat_residual_logging(self):
def f(x):
x = jnp.sin(x)
@ -11126,7 +11132,7 @@ class AutodidaxTest(jtu.JaxTestCase):
class GarbageCollectionTest(jtu.JaxTestCase):
@jtu.thread_hostile_test()
@jtu.thread_unsafe_test() # GC isn't predictable
def test_xla_gc_callback(self):
# https://github.com/jax-ml/jax/issues/14882
x_np = np.arange(10, dtype='int32')

View File

@ -890,6 +890,7 @@ class ShardingTest(jtu.JaxTestCase):
r"factors: \[4, 2\] should evenly divide the shape\)"):
mps.shard_shape((8, 3))
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pmap_sharding_hash_eq(self):
if jax.device_count() < 2:
self.skipTest('Test needs >= 2 devices.')

View File

@ -1326,6 +1326,7 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
return lst
@jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe
class VmappableTest(jtu.JaxTestCase):
def test_basic(self):
with temporarily_register_named_array_vmappable():

View File

@ -95,6 +95,7 @@ def clear_cache() -> None:
cc._cache.clear()
@jtu.thread_unsafe_test_class() # mocking isn't thread-safe
class CompilationCacheTestCase(jtu.JaxTestCase):
def setUp(self):

View File

@ -293,6 +293,7 @@ class CoreTest(jtu.JaxTestCase):
assert d2_sin(0.0) == 0.0
assert d3_sin(0.0) == -1.0
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
def test_reference_cycles(self):
gc.collect()
@ -310,6 +311,7 @@ class CoreTest(jtu.JaxTestCase):
finally:
gc.set_debug(debug)
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
def test_reference_cycles_jit(self):
gc.collect()

View File

@ -59,6 +59,7 @@ class DebugCallbackTest(jtu.JaxTestCase):
jax.debug.callback("this is not debug.print!")
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintTest(jtu.JaxTestCase):
def tearDown(self):
@ -236,6 +237,7 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintTransformationTest(jtu.JaxTestCase):
def test_debug_print_batching(self):
@ -507,6 +509,7 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
jax.effects_barrier()
self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintControlFlowTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2):
@ -722,6 +725,7 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
b3: 2
"""))
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintParallelTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2):

View File

@ -34,6 +34,7 @@ def _create_array_cycle():
return weakref.ref(n1)
@jtu.thread_unsafe_test_class() # GC isn't predictable when threaded.
class GarbageCollectionGuardTest(jtu.JaxTestCase):
def test_gced_array_is_not_logged_by_default(self):

View File

@ -28,6 +28,7 @@ import numpy as np
jax.config.parse_flags_with_absl()
@jtu.thread_unsafe_test_class() # infeed isn't thread-safe
class InfeedTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -269,6 +269,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
@jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -2138,6 +2138,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = jnp.array([])
self.assertAllClose(ans, expected)
@jtu.thread_unsafe_test() # Cache eviction means we might retrace
def testCaching(self):
def cond(x):
assert python_should_be_executing
@ -3033,6 +3034,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(y, x)
self.assertIsInstance(y, jax.Array)
@jtu.thread_unsafe_test() # live_arrays count isn't thread-safe
def test_cond_memory_leak(self):
# https://github.com/jax-ml/jax/issues/12719

View File

@ -3886,6 +3886,9 @@ def bake_vmap(batched_args, batch_dims):
return ys, bdim_out
# All tests in this test class are thread-hostile because they add and remove
# primitives from global maps.
@jtu.thread_unsafe_test_class() # registration isn't thread-safe
class CustomElementTypesTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -197,6 +197,7 @@ def _callable_generators(dtype):
jax_debug_nans=True,
jax_numpy_rank_promotion='raise',
jax_traceback_filtering='off')
@jtu.thread_unsafe_test_class() # matplotlib isn't thread-safe
class LobpcgTest(jtu.JaxTestCase):
def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):

View File

@ -123,6 +123,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
val = jax.random.normal(rng, ())
self.assert_committed_to_device(val, device)
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
def test_primitive_compilation_cache(self):
devices = self.get_devices()

View File

@ -2144,6 +2144,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile()
compiled(a1) # no error
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pjit_single_device_sharding_add(self):
a = np.array([1, 2, 3], dtype=jnp.float32)
b = np.array([4, 5, 6], dtype=jnp.float32)
@ -2399,6 +2400,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out1_sharding_id, out3_sharding_id)
self.assertEqual(out2_sharding_id, out3_sharding_id)
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_out_sharding_indices_id_cache_hit(self):
shape = (8, 2)
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
@ -2863,6 +2865,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
x_is_tracer, y_is_tracer = False, True
assert f_mixed(x=2, y=3) == 1
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pjit_kwargs(self):
a = jnp.arange(8.)
b = jnp.arange(4.)
@ -3507,6 +3510,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
self.assertEqual(cache_info4.misses, cache_info3.misses)
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))
@ -3600,6 +3604,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out3.sharding, NamedSharding)
self.assertEqual(count(), 1)
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_jit_mul_sum_sharding_preserved(self):
if config.use_shardy_partitioner.value:
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")

View File

@ -15,6 +15,7 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
import contextlib
from functools import partial
import itertools as it
import gc
@ -3195,19 +3196,11 @@ class EagerPmapMixin:
def setUp(self):
super().setUp()
self.eager_pmap_enabled = config.eager_pmap.value
self.jit_disabled = config.disable_jit.value
config.update('jax_disable_jit', True)
config.update('jax_eager_pmap', True)
self.warning_ctx = jtu.ignore_warning(
message="Some donated buffers were not usable", category=UserWarning)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
config.update('jax_eager_pmap', self.eager_pmap_enabled)
config.update('jax_disable_jit', self.jit_disabled)
super().tearDown()
stack = contextlib.ExitStack()
stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
stack.enter_context(jtu.ignore_warning(
message="Some donated buffers were not usable", category=UserWarning))
self.addCleanup(stack.close)
@jtu.pytest_mark_if_available('multiaccelerator')
class PythonPmapEagerTest(EagerPmapMixin, PythonPmapTest):

View File

@ -53,6 +53,7 @@ except ImportError:
jax.config.parse_flags_with_absl()
@jtu.thread_unsafe_test_class() # profiler isn't thread-safe
class ProfilerTest(unittest.TestCase):
# These tests simply test that the profiler API does not crash; they do not
# check functional correctness.

View File

@ -548,6 +548,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
@with_pure_and_io_callbacks
@jtu.thread_unsafe_test() # logging isn't thread-safe
def test_exception_in_callback(self, *, callback):
def fail(x):
raise RuntimeError("Ooops")
@ -570,6 +571,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertIn("Traceback (most recent call last)", output)
@with_pure_and_io_callbacks
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
def test_compilation_caching(self, *, callback):
def f_outside(x):
return 2 * x

View File

@ -955,6 +955,7 @@ if CAN_USE_HYPOTHESIS:
assert next(idx_, None) is None
return idx
@jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe
class StateHypothesisTest(jtu.JaxTestCase):
@hp.given(get_vmap_params())
@ -1711,6 +1712,7 @@ if CAN_USE_HYPOTHESIS:
min_dim=max(f1.min_dim, f2.min_dim),
max_dim=min(f1.max_dim, f2.max_dim))
@jtu.thread_unsafe_test_class() # because of hypothesis
class RunStateHypothesisTest(jtu.JaxTestCase):
@jax.legacy_prng_key('allow')