Make JAX test suite pass (at least most of the time) with multiple threads enabled.

Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
This commit is contained in:
Peter Hawkins 2025-01-10 06:58:01 -08:00 committed by jax authors
parent 86643a1b3e
commit c61b2f6b81
19 changed files with 88 additions and 38 deletions

View File

@ -1013,16 +1013,27 @@ if hasattr(util, 'Mutex'):
_test_rwlock = util.Mutex() _test_rwlock = util.Mutex()
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
_test_rwlock.reader_lock() if getattr(test.__class__, "thread_hostile", False):
try: _test_rwlock.writer_lock()
test(result) # type: ignore try:
finally: test(result) # type: ignore
_test_rwlock.reader_unlock() finally:
_test_rwlock.writer_unlock()
else:
_test_rwlock.reader_lock()
try:
test(result) # type: ignore
finally:
_test_rwlock.reader_unlock()
@contextmanager @contextmanager
def thread_hostile_test(): def thread_unsafe_test():
"Decorator for tests that are not thread-safe." """Decorator for tests that are not thread-safe.
Note: this decorator (naturally) only applies to what it wraps, not to, say,
code in separate setUp() or tearDown() methods.
"""
if TEST_NUM_THREADS.value <= 0: if TEST_NUM_THREADS.value <= 0:
yield yield
return return
@ -1048,9 +1059,19 @@ else:
@contextmanager @contextmanager
def thread_hostile_test(): def thread_unsafe_test():
yield # No reader-writer lock, so we get no parallelism. yield # No reader-writer lock, so we get no parallelism.
def thread_unsafe_test_class():
"Decorator that marks a TestCase class as thread-hostile."
def f(klass):
assert issubclass(klass, unittest.TestCase), type(klass)
klass.thread_hostile = True
return klass
return f
class ThreadSafeTestResult: class ThreadSafeTestResult:
""" """
Wraps a TestResult to make it thread safe. Wraps a TestResult to make it thread safe.
@ -1074,8 +1095,9 @@ class ThreadSafeTestResult:
def stopTest(self, test: unittest.TestCase): def stopTest(self, test: unittest.TestCase):
stop_time = time.time() stop_time = time.time()
with self.lock: with self.lock:
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can # If test_result is an ABSL _TextAndXMLTestResult we override how it gets
# override how it gets the time. # the time. This affects the timing that shows up in the XML output
# consumed by CI.
time_getter = getattr(self.test_result, "time_getter", None) time_getter = getattr(self.test_result, "time_getter", None)
try: try:
self.test_result.time_getter = lambda: self.start_time self.test_result.time_getter = lambda: self.start_time

View File

@ -632,7 +632,7 @@ class JitTest(jtu.BufferDonationTestCase):
python_should_be_executing = False python_should_be_executing = False
jit(f)(3) jit(f)(3)
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # GC effects aren't predictable with threads
def test_jit_cache_clear(self): def test_jit_cache_clear(self):
@jit @jit
def f(x, y): def f(x, y):
@ -1605,6 +1605,7 @@ class APITest(jtu.JaxTestCase):
assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0) assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0)) assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))
@jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace.
def test_grad_of_jit(self): def test_grad_of_jit(self):
side = [] side = []
@ -1618,6 +1619,7 @@ class APITest(jtu.JaxTestCase):
assert grad(f)(2.0) == 4.0 assert grad(f)(2.0) == 4.0
assert len(side) == 1 assert len(side) == 1
@jtu.thread_unsafe_test() # Concurrent ache eviction means we may retrace.
def test_jit_of_grad(self): def test_jit_of_grad(self):
side = [] side = []
@ -2589,7 +2591,7 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False) self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
self.assertEqual(pytree[3], 4) self.assertEqual(pytree[3], 4)
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads
def test_devicearray_weakref_friendly(self): def test_devicearray_weakref_friendly(self):
x = device_put(1.) x = device_put(1.)
y = weakref.ref(x) y = weakref.ref(x)
@ -2738,7 +2740,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(count(), 1) self.assertEqual(count(), 1)
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # jit cache misses aren't thread safe
def test_jit_infer_params_cache(self): def test_jit_infer_params_cache(self):
def f(x): def f(x):
return x return x
@ -3329,7 +3331,7 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"): with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"):
jax.grad(lambda x: x)(x) jax.grad(lambda x: x)(x)
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # logging isn't thread-safe
def test_jit_compilation_time_logging(self): def test_jit_compilation_time_logging(self):
@api.jit @api.jit
def f(x): def f(x):
@ -3418,7 +3420,7 @@ class APITest(jtu.JaxTestCase):
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer()) self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertEqual(z2, 1) self.assertEqual(z2, 1)
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # monkey-patching mlir.jaxpr_subcomp isn't thread-safe
def test_nested_jit_hoisting(self): def test_nested_jit_hoisting(self):
@api.jit @api.jit
def f(x, y): def f(x, y):
@ -3456,7 +3458,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul') self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add') self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
def test_primitive_compilation_cache(self): def test_primitive_compilation_cache(self):
with jtu.count_primitive_compiles() as count: with jtu.count_primitive_compiles() as count:
lax.add(1, 2) lax.add(1, 2)
@ -4016,7 +4018,7 @@ class APITest(jtu.JaxTestCase):
a2 = jnp.array(((x, x), [x, x])) a2 = jnp.array(((x, x), [x, x]))
self.assertAllClose(np.array(((1, 1), (1, 1))), a2) self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe
def test_eval_shape_weak_type(self): def test_eval_shape_weak_type(self):
# https://github.com/jax-ml/jax/issues/23302 # https://github.com/jax-ml/jax/issues/23302
arr = jax.numpy.array(1) arr = jax.numpy.array(1)
@ -4145,7 +4147,7 @@ class APITest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(jnp.dot)(x, x) jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
self.assertIn('Precision.HIGH', str(jaxpr)) self.assertIn('Precision.HIGH', str(jaxpr))
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # Updating global configs is not thread-safe.
def test_dot_precision_forces_retrace(self): def test_dot_precision_forces_retrace(self):
num_traces = 0 num_traces = 0
@ -4318,7 +4320,7 @@ class APITest(jtu.JaxTestCase):
api.make_jaxpr(lambda: jnp.array(3))() api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count(), 0) self.assertEqual(count(), 0)
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # Updating global configs is not thread-safe.
def test_rank_promotion_forces_retrace(self): def test_rank_promotion_forces_retrace(self):
num_traces = 0 num_traces = 0
@ -4459,7 +4461,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})") self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})")
self.assertEqual(jfoo.__module__, "jax") self.assertEqual(jfoo.__module__, "jax")
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace
def test_inner_jit_function_retracing(self): def test_inner_jit_function_retracing(self):
# https://github.com/jax-ml/jax/issues/7155 # https://github.com/jax-ml/jax/issues/7155
inner_count = outer_count = 0 inner_count = outer_count = 0
@ -4507,6 +4509,7 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"): with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"):
jax.device_put(jnp.arange(8), 'cpu') jax.device_put(jnp.arange(8), 'cpu')
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_clear_cache(self): def test_clear_cache(self):
@jax.jit @jax.jit
def add(x): def add(x):
@ -4525,6 +4528,7 @@ class APITest(jtu.JaxTestCase):
tracing_add_count += 1 tracing_add_count += 1
self.assertEqual(tracing_add_count, 2) self.assertEqual(tracing_add_count, 2)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations(self): def test_cache_miss_explanations(self):
@jax.jit @jax.jit
def f(x, y): def f(x, y):
@ -4584,6 +4588,7 @@ class APITest(jtu.JaxTestCase):
msg = cm.output[0] msg = cm.output[0]
self.assertIn("tracing context doesn't match", msg) self.assertIn("tracing context doesn't match", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_new_function_in_loop(self): def test_cache_miss_explanations_new_function_in_loop(self):
@jax.jit @jax.jit
def f(x, y): def f(x, y):
@ -4605,6 +4610,7 @@ class APITest(jtu.JaxTestCase):
_, msg = cm.output _, msg = cm.output
self.assertIn('another function defined on the same line', msg) self.assertIn('another function defined on the same line', msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_unpacks_transforms(self): def test_cache_miss_explanations_unpacks_transforms(self):
# Tests that the explain_tracing_cache_miss() function does not throw an # Tests that the explain_tracing_cache_miss() function does not throw an
# error when unpacking `transforms` with a length greater than 3. # error when unpacking `transforms` with a length greater than 3.
@ -4701,7 +4707,7 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "ndim of its first argument"): with self.assertRaisesRegex(ValueError, "ndim of its first argument"):
jax.sharding.Mesh(jax.devices(), ("x", "y")) jax.sharding.Mesh(jax.devices(), ("x", "y"))
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # weakref gc doesn't seem predictable
def test_jit_boundmethod_reference_cycle(self): def test_jit_boundmethod_reference_cycle(self):
class A: class A:
def __init__(self): def __init__(self):
@ -4840,7 +4846,7 @@ class RematTest(jtu.JaxTestCase):
('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
]) ])
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # monkey patches sin_p and cos_p
def test_remat_basic(self, remat): def test_remat_basic(self, remat):
@remat @remat
def g(x): def g(x):
@ -5178,7 +5184,7 @@ class RematTest(jtu.JaxTestCase):
('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
]) ])
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # monkey patches sin_p
def test_remat_no_redundant_flops(self, remat): def test_remat_no_redundant_flops(self, remat):
# see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584 # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584
@ -6422,7 +6428,7 @@ class RematTest(jtu.JaxTestCase):
self.assertIn(' sin ', str(jaxpr)) self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr))
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # logging isn't thread-safe
def test_remat_residual_logging(self): def test_remat_residual_logging(self):
def f(x): def f(x):
x = jnp.sin(x) x = jnp.sin(x)
@ -11126,7 +11132,7 @@ class AutodidaxTest(jtu.JaxTestCase):
class GarbageCollectionTest(jtu.JaxTestCase): class GarbageCollectionTest(jtu.JaxTestCase):
@jtu.thread_hostile_test() @jtu.thread_unsafe_test() # GC isn't predictable
def test_xla_gc_callback(self): def test_xla_gc_callback(self):
# https://github.com/jax-ml/jax/issues/14882 # https://github.com/jax-ml/jax/issues/14882
x_np = np.arange(10, dtype='int32') x_np = np.arange(10, dtype='int32')

View File

@ -890,6 +890,7 @@ class ShardingTest(jtu.JaxTestCase):
r"factors: \[4, 2\] should evenly divide the shape\)"): r"factors: \[4, 2\] should evenly divide the shape\)"):
mps.shard_shape((8, 3)) mps.shard_shape((8, 3))
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pmap_sharding_hash_eq(self): def test_pmap_sharding_hash_eq(self):
if jax.device_count() < 2: if jax.device_count() < 2:
self.skipTest('Test needs >= 2 devices.') self.skipTest('Test needs >= 2 devices.')

View File

@ -1326,6 +1326,7 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
return lst return lst
@jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe
class VmappableTest(jtu.JaxTestCase): class VmappableTest(jtu.JaxTestCase):
def test_basic(self): def test_basic(self):
with temporarily_register_named_array_vmappable(): with temporarily_register_named_array_vmappable():

View File

@ -95,6 +95,7 @@ def clear_cache() -> None:
cc._cache.clear() cc._cache.clear()
@jtu.thread_unsafe_test_class() # mocking isn't thread-safe
class CompilationCacheTestCase(jtu.JaxTestCase): class CompilationCacheTestCase(jtu.JaxTestCase):
def setUp(self): def setUp(self):

View File

@ -293,6 +293,7 @@ class CoreTest(jtu.JaxTestCase):
assert d2_sin(0.0) == 0.0 assert d2_sin(0.0) == 0.0
assert d3_sin(0.0) == -1.0 assert d3_sin(0.0) == -1.0
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
def test_reference_cycles(self): def test_reference_cycles(self):
gc.collect() gc.collect()
@ -310,6 +311,7 @@ class CoreTest(jtu.JaxTestCase):
finally: finally:
gc.set_debug(debug) gc.set_debug(debug)
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
def test_reference_cycles_jit(self): def test_reference_cycles_jit(self):
gc.collect() gc.collect()

View File

@ -59,6 +59,7 @@ class DebugCallbackTest(jtu.JaxTestCase):
jax.debug.callback("this is not debug.print!") jax.debug.callback("this is not debug.print!")
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintTest(jtu.JaxTestCase): class DebugPrintTest(jtu.JaxTestCase):
def tearDown(self): def tearDown(self):
@ -236,6 +237,7 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "[1.23 2.35 0. ]\n") self.assertEqual(output(), "[1.23 2.35 0. ]\n")
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintTransformationTest(jtu.JaxTestCase): class DebugPrintTransformationTest(jtu.JaxTestCase):
def test_debug_print_batching(self): def test_debug_print_batching(self):
@ -507,6 +509,7 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
jax.effects_barrier() jax.effects_barrier()
self.assertEqual(output(), "hello bwd: 2.0 3.0\n") self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintControlFlowTest(jtu.JaxTestCase): class DebugPrintControlFlowTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2): def _assertLinesEqual(self, text1, text2):
@ -722,6 +725,7 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
b3: 2 b3: 2
""")) """))
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
class DebugPrintParallelTest(jtu.JaxTestCase): class DebugPrintParallelTest(jtu.JaxTestCase):
def _assertLinesEqual(self, text1, text2): def _assertLinesEqual(self, text1, text2):

View File

@ -34,6 +34,7 @@ def _create_array_cycle():
return weakref.ref(n1) return weakref.ref(n1)
@jtu.thread_unsafe_test_class() # GC isn't predictable when threaded.
class GarbageCollectionGuardTest(jtu.JaxTestCase): class GarbageCollectionGuardTest(jtu.JaxTestCase):
def test_gced_array_is_not_logged_by_default(self): def test_gced_array_is_not_logged_by_default(self):

View File

@ -28,6 +28,7 @@ import numpy as np
jax.config.parse_flags_with_absl() jax.config.parse_flags_with_absl()
@jtu.thread_unsafe_test_class() # infeed isn't thread-safe
class InfeedTest(jtu.JaxTestCase): class InfeedTest(jtu.JaxTestCase):
def setUp(self): def setUp(self):

View File

@ -269,6 +269,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect}) self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
@jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls
class EffectfulJaxprLoweringTest(jtu.JaxTestCase): class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def setUp(self): def setUp(self):

View File

@ -2138,6 +2138,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = jnp.array([]) expected = jnp.array([])
self.assertAllClose(ans, expected) self.assertAllClose(ans, expected)
@jtu.thread_unsafe_test() # Cache eviction means we might retrace
def testCaching(self): def testCaching(self):
def cond(x): def cond(x):
assert python_should_be_executing assert python_should_be_executing
@ -3033,6 +3034,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(y, x) self.assertEqual(y, x)
self.assertIsInstance(y, jax.Array) self.assertIsInstance(y, jax.Array)
@jtu.thread_unsafe_test() # live_arrays count isn't thread-safe
def test_cond_memory_leak(self): def test_cond_memory_leak(self):
# https://github.com/jax-ml/jax/issues/12719 # https://github.com/jax-ml/jax/issues/12719

View File

@ -3886,6 +3886,9 @@ def bake_vmap(batched_args, batch_dims):
return ys, bdim_out return ys, bdim_out
# All tests in this test class are thread-hostile because they add and remove
# primitives from global maps.
@jtu.thread_unsafe_test_class() # registration isn't thread-safe
class CustomElementTypesTest(jtu.JaxTestCase): class CustomElementTypesTest(jtu.JaxTestCase):
def setUp(self): def setUp(self):

View File

@ -197,6 +197,7 @@ def _callable_generators(dtype):
jax_debug_nans=True, jax_debug_nans=True,
jax_numpy_rank_promotion='raise', jax_numpy_rank_promotion='raise',
jax_traceback_filtering='off') jax_traceback_filtering='off')
@jtu.thread_unsafe_test_class() # matplotlib isn't thread-safe
class LobpcgTest(jtu.JaxTestCase): class LobpcgTest(jtu.JaxTestCase):
def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype): def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):

View File

@ -123,6 +123,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
val = jax.random.normal(rng, ()) val = jax.random.normal(rng, ())
self.assert_committed_to_device(val, device) self.assert_committed_to_device(val, device)
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
def test_primitive_compilation_cache(self): def test_primitive_compilation_cache(self):
devices = self.get_devices() devices = self.get_devices()

View File

@ -2144,6 +2144,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile() compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile()
compiled(a1) # no error compiled(a1) # no error
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pjit_single_device_sharding_add(self): def test_pjit_single_device_sharding_add(self):
a = np.array([1, 2, 3], dtype=jnp.float32) a = np.array([1, 2, 3], dtype=jnp.float32)
b = np.array([4, 5, 6], dtype=jnp.float32) b = np.array([4, 5, 6], dtype=jnp.float32)
@ -2399,6 +2400,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out1_sharding_id, out3_sharding_id) self.assertEqual(out1_sharding_id, out3_sharding_id)
self.assertEqual(out2_sharding_id, out3_sharding_id) self.assertEqual(out2_sharding_id, out3_sharding_id)
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_out_sharding_indices_id_cache_hit(self): def test_out_sharding_indices_id_cache_hit(self):
shape = (8, 2) shape = (8, 2)
mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh = jtu.create_mesh((4, 2), ('x', 'y'))
@ -2863,6 +2865,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
x_is_tracer, y_is_tracer = False, True x_is_tracer, y_is_tracer = False, True
assert f_mixed(x=2, y=3) == 1 assert f_mixed(x=2, y=3) == 1
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pjit_kwargs(self): def test_pjit_kwargs(self):
a = jnp.arange(8.) a = jnp.arange(8.)
b = jnp.arange(4.) b = jnp.arange(4.)
@ -3507,6 +3510,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(cache_info4.hits, cache_info3.hits + 1) self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
self.assertEqual(cache_info4.misses, cache_info3.misses) self.assertEqual(cache_info4.misses, cache_info3.misses)
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_cache_hit_pjit_lower_with_cpp_cache_miss(self): def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y')) mesh = jtu.create_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x')) ns = NamedSharding(mesh, P('x'))
@ -3600,6 +3604,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out3.sharding, NamedSharding) self.assertIsInstance(out3.sharding, NamedSharding)
self.assertEqual(count(), 1) self.assertEqual(count(), 1)
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_jit_mul_sum_sharding_preserved(self): def test_jit_mul_sum_sharding_preserved(self):
if config.use_shardy_partitioner.value: if config.use_shardy_partitioner.value:
raise unittest.SkipTest("Shardy doesn't support PositionalSharding") raise unittest.SkipTest("Shardy doesn't support PositionalSharding")

View File

@ -15,6 +15,7 @@
from __future__ import annotations from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import contextlib
from functools import partial from functools import partial
import itertools as it import itertools as it
import gc import gc
@ -3195,19 +3196,11 @@ class EagerPmapMixin:
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.eager_pmap_enabled = config.eager_pmap.value stack = contextlib.ExitStack()
self.jit_disabled = config.disable_jit.value stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
config.update('jax_disable_jit', True) stack.enter_context(jtu.ignore_warning(
config.update('jax_eager_pmap', True) message="Some donated buffers were not usable", category=UserWarning))
self.warning_ctx = jtu.ignore_warning( self.addCleanup(stack.close)
message="Some donated buffers were not usable", category=UserWarning)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
config.update('jax_eager_pmap', self.eager_pmap_enabled)
config.update('jax_disable_jit', self.jit_disabled)
super().tearDown()
@jtu.pytest_mark_if_available('multiaccelerator') @jtu.pytest_mark_if_available('multiaccelerator')
class PythonPmapEagerTest(EagerPmapMixin, PythonPmapTest): class PythonPmapEagerTest(EagerPmapMixin, PythonPmapTest):

View File

@ -53,6 +53,7 @@ except ImportError:
jax.config.parse_flags_with_absl() jax.config.parse_flags_with_absl()
@jtu.thread_unsafe_test_class() # profiler isn't thread-safe
class ProfilerTest(unittest.TestCase): class ProfilerTest(unittest.TestCase):
# These tests simply test that the profiler API does not crash; they do not # These tests simply test that the profiler API does not crash; they do not
# check functional correctness. # check functional correctness.

View File

@ -548,6 +548,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.) np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
@with_pure_and_io_callbacks @with_pure_and_io_callbacks
@jtu.thread_unsafe_test() # logging isn't thread-safe
def test_exception_in_callback(self, *, callback): def test_exception_in_callback(self, *, callback):
def fail(x): def fail(x):
raise RuntimeError("Ooops") raise RuntimeError("Ooops")
@ -570,6 +571,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
self.assertIn("Traceback (most recent call last)", output) self.assertIn("Traceback (most recent call last)", output)
@with_pure_and_io_callbacks @with_pure_and_io_callbacks
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
def test_compilation_caching(self, *, callback): def test_compilation_caching(self, *, callback):
def f_outside(x): def f_outside(x):
return 2 * x return 2 * x

View File

@ -955,6 +955,7 @@ if CAN_USE_HYPOTHESIS:
assert next(idx_, None) is None assert next(idx_, None) is None
return idx return idx
@jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe
class StateHypothesisTest(jtu.JaxTestCase): class StateHypothesisTest(jtu.JaxTestCase):
@hp.given(get_vmap_params()) @hp.given(get_vmap_params())
@ -1711,6 +1712,7 @@ if CAN_USE_HYPOTHESIS:
min_dim=max(f1.min_dim, f2.min_dim), min_dim=max(f1.min_dim, f2.min_dim),
max_dim=min(f1.max_dim, f2.max_dim)) max_dim=min(f1.max_dim, f2.max_dim))
@jtu.thread_unsafe_test_class() # because of hypothesis
class RunStateHypothesisTest(jtu.JaxTestCase): class RunStateHypothesisTest(jtu.JaxTestCase):
@jax.legacy_prng_key('allow') @jax.legacy_prng_key('allow')