From c61b2f6b819c5dadab05d4778b67f28aeeb64b9d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 10 Jan 2025 06:58:01 -0800 Subject: [PATCH] Make JAX test suite pass (at least most of the time) with multiple threads enabled. Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile. PiperOrigin-RevId: 714037277 --- jax/_src/test_util.py | 42 ++++++++++++++++++++------ tests/api_test.py | 36 +++++++++++++--------- tests/array_test.py | 1 + tests/batching_test.py | 1 + tests/compilation_cache_test.py | 1 + tests/core_test.py | 2 ++ tests/debugging_primitives_test.py | 4 +++ tests/garbage_collection_guard_test.py | 1 + tests/infeed_test.py | 1 + tests/jaxpr_effects_test.py | 1 + tests/lax_control_flow_test.py | 2 ++ tests/lax_test.py | 3 ++ tests/lobpcg_test.py | 1 + tests/multi_device_test.py | 1 + tests/pjit_test.py | 5 +++ tests/pmap_test.py | 19 ++++-------- tests/profiler_test.py | 1 + tests/python_callback_test.py | 2 ++ tests/state_test.py | 2 ++ 19 files changed, 88 insertions(+), 38 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index eee288d1b..ce2e54ce3 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1013,16 +1013,27 @@ if hasattr(util, 'Mutex'): _test_rwlock = util.Mutex() def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): - _test_rwlock.reader_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.reader_unlock() + if getattr(test.__class__, "thread_hostile", False): + _test_rwlock.writer_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.writer_unlock() + else: + _test_rwlock.reader_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.reader_unlock() @contextmanager - def thread_hostile_test(): - "Decorator for tests that are not thread-safe." + def thread_unsafe_test(): + """Decorator for tests that are not thread-safe. + + Note: this decorator (naturally) only applies to what it wraps, not to, say, + code in separate setUp() or tearDown() methods. + """ if TEST_NUM_THREADS.value <= 0: yield return @@ -1048,9 +1059,19 @@ else: @contextmanager - def thread_hostile_test(): + def thread_unsafe_test(): yield # No reader-writer lock, so we get no parallelism. + +def thread_unsafe_test_class(): + "Decorator that marks a TestCase class as thread-hostile." + def f(klass): + assert issubclass(klass, unittest.TestCase), type(klass) + klass.thread_hostile = True + return klass + return f + + class ThreadSafeTestResult: """ Wraps a TestResult to make it thread safe. @@ -1074,8 +1095,9 @@ class ThreadSafeTestResult: def stopTest(self, test: unittest.TestCase): stop_time = time.time() with self.lock: - # We assume test_result is an ABSL _TextAndXMLTestResult, so we can - # override how it gets the time. + # If test_result is an ABSL _TextAndXMLTestResult we override how it gets + # the time. This affects the timing that shows up in the XML output + # consumed by CI. time_getter = getattr(self.test_result, "time_getter", None) try: self.test_result.time_getter = lambda: self.start_time diff --git a/tests/api_test.py b/tests/api_test.py index bff8ff3d0..379c63900 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -632,7 +632,7 @@ class JitTest(jtu.BufferDonationTestCase): python_should_be_executing = False jit(f)(3) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # GC effects aren't predictable with threads def test_jit_cache_clear(self): @jit def f(x, y): @@ -1605,6 +1605,7 @@ class APITest(jtu.JaxTestCase): assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0) assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0)) + @jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace. def test_grad_of_jit(self): side = [] @@ -1618,6 +1619,7 @@ class APITest(jtu.JaxTestCase): assert grad(f)(2.0) == 4.0 assert len(side) == 1 + @jtu.thread_unsafe_test() # Concurrent ache eviction means we may retrace. def test_jit_of_grad(self): side = [] @@ -2589,7 +2591,7 @@ class APITest(jtu.JaxTestCase): self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False) self.assertEqual(pytree[3], 4) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads def test_devicearray_weakref_friendly(self): x = device_put(1.) y = weakref.ref(x) @@ -2738,7 +2740,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(count(), 1) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # jit cache misses aren't thread safe def test_jit_infer_params_cache(self): def f(x): return x @@ -3329,7 +3331,7 @@ class APITest(jtu.JaxTestCase): with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"): jax.grad(lambda x: x)(x) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # logging isn't thread-safe def test_jit_compilation_time_logging(self): @api.jit def f(x): @@ -3418,7 +3420,7 @@ class APITest(jtu.JaxTestCase): self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer()) self.assertEqual(z2, 1) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # monkey-patching mlir.jaxpr_subcomp isn't thread-safe def test_nested_jit_hoisting(self): @api.jit def f(x, y): @@ -3456,7 +3458,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul') self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add') - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe def test_primitive_compilation_cache(self): with jtu.count_primitive_compiles() as count: lax.add(1, 2) @@ -4016,7 +4018,7 @@ class APITest(jtu.JaxTestCase): a2 = jnp.array(((x, x), [x, x])) self.assertAllClose(np.array(((1, 1), (1, 1))), a2) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe def test_eval_shape_weak_type(self): # https://github.com/jax-ml/jax/issues/23302 arr = jax.numpy.array(1) @@ -4145,7 +4147,7 @@ class APITest(jtu.JaxTestCase): jaxpr = jax.make_jaxpr(jnp.dot)(x, x) self.assertIn('Precision.HIGH', str(jaxpr)) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # Updating global configs is not thread-safe. def test_dot_precision_forces_retrace(self): num_traces = 0 @@ -4318,7 +4320,7 @@ class APITest(jtu.JaxTestCase): api.make_jaxpr(lambda: jnp.array(3))() self.assertEqual(count(), 0) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # Updating global configs is not thread-safe. def test_rank_promotion_forces_retrace(self): num_traces = 0 @@ -4459,7 +4461,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})") self.assertEqual(jfoo.__module__, "jax") - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace def test_inner_jit_function_retracing(self): # https://github.com/jax-ml/jax/issues/7155 inner_count = outer_count = 0 @@ -4507,6 +4509,7 @@ class APITest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"): jax.device_put(jnp.arange(8), 'cpu') + @jtu.thread_unsafe_test() # logging is not thread-safe def test_clear_cache(self): @jax.jit def add(x): @@ -4525,6 +4528,7 @@ class APITest(jtu.JaxTestCase): tracing_add_count += 1 self.assertEqual(tracing_add_count, 2) + @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations(self): @jax.jit def f(x, y): @@ -4584,6 +4588,7 @@ class APITest(jtu.JaxTestCase): msg = cm.output[0] self.assertIn("tracing context doesn't match", msg) + @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_new_function_in_loop(self): @jax.jit def f(x, y): @@ -4605,6 +4610,7 @@ class APITest(jtu.JaxTestCase): _, msg = cm.output self.assertIn('another function defined on the same line', msg) + @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_unpacks_transforms(self): # Tests that the explain_tracing_cache_miss() function does not throw an # error when unpacking `transforms` with a length greater than 3. @@ -4701,7 +4707,7 @@ class APITest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "ndim of its first argument"): jax.sharding.Mesh(jax.devices(), ("x", "y")) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # weakref gc doesn't seem predictable def test_jit_boundmethod_reference_cycle(self): class A: def __init__(self): @@ -4840,7 +4846,7 @@ class RematTest(jtu.JaxTestCase): ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # monkey patches sin_p and cos_p def test_remat_basic(self, remat): @remat def g(x): @@ -5178,7 +5184,7 @@ class RematTest(jtu.JaxTestCase): ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # monkey patches sin_p def test_remat_no_redundant_flops(self, remat): # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584 @@ -6422,7 +6428,7 @@ class RematTest(jtu.JaxTestCase): self.assertIn(' sin ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr)) - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # logging isn't thread-safe def test_remat_residual_logging(self): def f(x): x = jnp.sin(x) @@ -11126,7 +11132,7 @@ class AutodidaxTest(jtu.JaxTestCase): class GarbageCollectionTest(jtu.JaxTestCase): - @jtu.thread_hostile_test() + @jtu.thread_unsafe_test() # GC isn't predictable def test_xla_gc_callback(self): # https://github.com/jax-ml/jax/issues/14882 x_np = np.arange(10, dtype='int32') diff --git a/tests/array_test.py b/tests/array_test.py index afcdad376..a620ed55a 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -890,6 +890,7 @@ class ShardingTest(jtu.JaxTestCase): r"factors: \[4, 2\] should evenly divide the shape\)"): mps.shard_shape((8, 3)) + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pmap_sharding_hash_eq(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') diff --git a/tests/batching_test.py b/tests/batching_test.py index 608053c23..bab18ce53 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -1326,6 +1326,7 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]: return lst +@jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe class VmappableTest(jtu.JaxTestCase): def test_basic(self): with temporarily_register_named_array_vmappable(): diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 27ebab887..ef245bc8d 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -95,6 +95,7 @@ def clear_cache() -> None: cc._cache.clear() +@jtu.thread_unsafe_test_class() # mocking isn't thread-safe class CompilationCacheTestCase(jtu.JaxTestCase): def setUp(self): diff --git a/tests/core_test.py b/tests/core_test.py index 7ca941c69..e4cbfd562 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -293,6 +293,7 @@ class CoreTest(jtu.JaxTestCase): assert d2_sin(0.0) == 0.0 assert d3_sin(0.0) == -1.0 + @jtu.thread_unsafe_test() # gc isn't predictable when threaded def test_reference_cycles(self): gc.collect() @@ -310,6 +311,7 @@ class CoreTest(jtu.JaxTestCase): finally: gc.set_debug(debug) + @jtu.thread_unsafe_test() # gc isn't predictable when threaded def test_reference_cycles_jit(self): gc.collect() diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 0fc9665ce..392e544d8 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -59,6 +59,7 @@ class DebugCallbackTest(jtu.JaxTestCase): jax.debug.callback("this is not debug.print!") +@jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintTest(jtu.JaxTestCase): def tearDown(self): @@ -236,6 +237,7 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "[1.23 2.35 0. ]\n") +@jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintTransformationTest(jtu.JaxTestCase): def test_debug_print_batching(self): @@ -507,6 +509,7 @@ class DebugPrintTransformationTest(jtu.JaxTestCase): jax.effects_barrier() self.assertEqual(output(), "hello bwd: 2.0 3.0\n") +@jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintControlFlowTest(jtu.JaxTestCase): def _assertLinesEqual(self, text1, text2): @@ -722,6 +725,7 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): b3: 2 """)) +@jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintParallelTest(jtu.JaxTestCase): def _assertLinesEqual(self, text1, text2): diff --git a/tests/garbage_collection_guard_test.py b/tests/garbage_collection_guard_test.py index 64b2baeff..5c34c6de2 100644 --- a/tests/garbage_collection_guard_test.py +++ b/tests/garbage_collection_guard_test.py @@ -34,6 +34,7 @@ def _create_array_cycle(): return weakref.ref(n1) +@jtu.thread_unsafe_test_class() # GC isn't predictable when threaded. class GarbageCollectionGuardTest(jtu.JaxTestCase): def test_gced_array_is_not_logged_by_default(self): diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 5dd52b416..060502ae6 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -28,6 +28,7 @@ import numpy as np jax.config.parse_flags_with_absl() +@jtu.thread_unsafe_test_class() # infeed isn't thread-safe class InfeedTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 922b37ffa..2d63a4834 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -269,6 +269,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase): self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect}) +@jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 68c5d45c9..a04892816 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2138,6 +2138,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): expected = jnp.array([]) self.assertAllClose(ans, expected) + @jtu.thread_unsafe_test() # Cache eviction means we might retrace def testCaching(self): def cond(x): assert python_should_be_executing @@ -3033,6 +3034,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertEqual(y, x) self.assertIsInstance(y, jax.Array) + @jtu.thread_unsafe_test() # live_arrays count isn't thread-safe def test_cond_memory_leak(self): # https://github.com/jax-ml/jax/issues/12719 diff --git a/tests/lax_test.py b/tests/lax_test.py index 9db2f5bcc..a2b3e8b62 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3886,6 +3886,9 @@ def bake_vmap(batched_args, batch_dims): return ys, bdim_out +# All tests in this test class are thread-hostile because they add and remove +# primitives from global maps. +@jtu.thread_unsafe_test_class() # registration isn't thread-safe class CustomElementTypesTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index 02d340abc..fc2b0df84 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -197,6 +197,7 @@ def _callable_generators(dtype): jax_debug_nans=True, jax_numpy_rank_promotion='raise', jax_traceback_filtering='off') +@jtu.thread_unsafe_test_class() # matplotlib isn't thread-safe class LobpcgTest(jtu.JaxTestCase): def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype): diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 1fc6fe1e9..38a37844e 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -123,6 +123,7 @@ class MultiDeviceTest(jtu.JaxTestCase): val = jax.random.normal(rng, ()) self.assert_committed_to_device(val, device) + @jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe def test_primitive_compilation_cache(self): devices = self.get_devices() diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8c737b886..17cba7faf 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2144,6 +2144,7 @@ class ArrayPjitTest(jtu.JaxTestCase): compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile() compiled(a1) # no error + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pjit_single_device_sharding_add(self): a = np.array([1, 2, 3], dtype=jnp.float32) b = np.array([4, 5, 6], dtype=jnp.float32) @@ -2399,6 +2400,7 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertEqual(out1_sharding_id, out3_sharding_id) self.assertEqual(out2_sharding_id, out3_sharding_id) + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_out_sharding_indices_id_cache_hit(self): shape = (8, 2) mesh = jtu.create_mesh((4, 2), ('x', 'y')) @@ -2863,6 +2865,7 @@ class ArrayPjitTest(jtu.JaxTestCase): x_is_tracer, y_is_tracer = False, True assert f_mixed(x=2, y=3) == 1 + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pjit_kwargs(self): a = jnp.arange(8.) b = jnp.arange(4.) @@ -3507,6 +3510,7 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertEqual(cache_info4.hits, cache_info3.hits + 1) self.assertEqual(cache_info4.misses, cache_info3.misses) + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_cache_hit_pjit_lower_with_cpp_cache_miss(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3600,6 +3604,7 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertIsInstance(out3.sharding, NamedSharding) self.assertEqual(count(), 1) + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_jit_mul_sum_sharding_preserved(self): if config.use_shardy_partitioner.value: raise unittest.SkipTest("Shardy doesn't support PositionalSharding") diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 4694c8155..7ca9a43b9 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -15,6 +15,7 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor +import contextlib from functools import partial import itertools as it import gc @@ -3195,19 +3196,11 @@ class EagerPmapMixin: def setUp(self): super().setUp() - self.eager_pmap_enabled = config.eager_pmap.value - self.jit_disabled = config.disable_jit.value - config.update('jax_disable_jit', True) - config.update('jax_eager_pmap', True) - self.warning_ctx = jtu.ignore_warning( - message="Some donated buffers were not usable", category=UserWarning) - self.warning_ctx.__enter__() - - def tearDown(self): - self.warning_ctx.__exit__(None, None, None) - config.update('jax_eager_pmap', self.eager_pmap_enabled) - config.update('jax_disable_jit', self.jit_disabled) - super().tearDown() + stack = contextlib.ExitStack() + stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True)) + stack.enter_context(jtu.ignore_warning( + message="Some donated buffers were not usable", category=UserWarning)) + self.addCleanup(stack.close) @jtu.pytest_mark_if_available('multiaccelerator') class PythonPmapEagerTest(EagerPmapMixin, PythonPmapTest): diff --git a/tests/profiler_test.py b/tests/profiler_test.py index f0909094a..b686d30ad 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -53,6 +53,7 @@ except ImportError: jax.config.parse_flags_with_absl() +@jtu.thread_unsafe_test_class() # profiler isn't thread-safe class ProfilerTest(unittest.TestCase): # These tests simply test that the profiler API does not crash; they do not # check functional correctness. diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 5a3b9bab5..9b937bd67 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -548,6 +548,7 @@ class PythonCallbackTest(jtu.JaxTestCase): np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.) @with_pure_and_io_callbacks + @jtu.thread_unsafe_test() # logging isn't thread-safe def test_exception_in_callback(self, *, callback): def fail(x): raise RuntimeError("Ooops") @@ -570,6 +571,7 @@ class PythonCallbackTest(jtu.JaxTestCase): self.assertIn("Traceback (most recent call last)", output) @with_pure_and_io_callbacks + @jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe def test_compilation_caching(self, *, callback): def f_outside(x): return 2 * x diff --git a/tests/state_test.py b/tests/state_test.py index cf204b667..6c9a5b127 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -955,6 +955,7 @@ if CAN_USE_HYPOTHESIS: assert next(idx_, None) is None return idx + @jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe class StateHypothesisTest(jtu.JaxTestCase): @hp.given(get_vmap_params()) @@ -1711,6 +1712,7 @@ if CAN_USE_HYPOTHESIS: min_dim=max(f1.min_dim, f2.min_dim), max_dim=min(f1.max_dim, f2.max_dim)) + @jtu.thread_unsafe_test_class() # because of hypothesis class RunStateHypothesisTest(jtu.JaxTestCase): @jax.legacy_prng_key('allow')