diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index eee288d1b..ce2e54ce3 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1013,16 +1013,27 @@ if hasattr(util, 'Mutex'):
   _test_rwlock = util.Mutex()
 
   def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
-    _test_rwlock.reader_lock()
-    try:
-      test(result)  # type: ignore
-    finally:
-      _test_rwlock.reader_unlock()
+    if getattr(test.__class__, "thread_hostile", False):
+      _test_rwlock.writer_lock()
+      try:
+        test(result)  # type: ignore
+      finally:
+        _test_rwlock.writer_unlock()
+    else:
+      _test_rwlock.reader_lock()
+      try:
+        test(result)  # type: ignore
+      finally:
+        _test_rwlock.reader_unlock()
 
 
   @contextmanager
-  def thread_hostile_test():
-    "Decorator for tests that are not thread-safe."
+  def thread_unsafe_test():
+    """Decorator for tests that are not thread-safe.
+
+    Note: this decorator (naturally) only applies to what it wraps, not to, say,
+    code in separate setUp() or tearDown() methods.
+    """
     if TEST_NUM_THREADS.value <= 0:
       yield
       return
@@ -1048,9 +1059,19 @@ else:
 
 
   @contextmanager
-  def thread_hostile_test():
+  def thread_unsafe_test():
     yield  # No reader-writer lock, so we get no parallelism.
 
+
+def thread_unsafe_test_class():
+  """Returns a class decorator that sets `thread_hostile = True` on a TestCase class."""
+  def f(klass):
+    assert issubclass(klass, unittest.TestCase), type(klass)  # NOTE(review): message is the metaclass, not klass itself
+    klass.thread_hostile = True  # read via getattr(test.__class__, "thread_hostile", False) in _run_one_test
+    return klass
+  return f
+
+
 class ThreadSafeTestResult:
   """
   Wraps a TestResult to make it thread safe.
@@ -1074,8 +1095,9 @@ class ThreadSafeTestResult:
   def stopTest(self, test: unittest.TestCase):
     stop_time = time.time()
     with self.lock:
-      # We assume test_result is an ABSL _TextAndXMLTestResult, so we can
-      # override how it gets the time.
+      # If test_result is an ABSL _TextAndXMLTestResult we override how it gets
+      # the time. This affects the timing that shows up in the XML output
+      # consumed by CI.
       time_getter = getattr(self.test_result, "time_getter", None)
       try:
         self.test_result.time_getter = lambda: self.start_time
diff --git a/tests/api_test.py b/tests/api_test.py
index bff8ff3d0..379c63900 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -632,7 +632,7 @@ class JitTest(jtu.BufferDonationTestCase):
     python_should_be_executing = False
     jit(f)(3)
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # GC effects aren't predictable with threads
   def test_jit_cache_clear(self):
     @jit
     def f(x, y):
@@ -1605,6 +1605,7 @@ class APITest(jtu.JaxTestCase):
     assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
     assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))
 
+  @jtu.thread_unsafe_test()  # Concurrent cache eviction means we may retrace.
   def test_grad_of_jit(self):
     side = []
 
@@ -1618,6 +1619,7 @@ class APITest(jtu.JaxTestCase):
     assert grad(f)(2.0) == 4.0
     assert len(side) == 1
 
+  @jtu.thread_unsafe_test()  # Concurrent cache eviction means we may retrace.
   def test_jit_of_grad(self):
     side = []
 
@@ -2589,7 +2591,7 @@ class APITest(jtu.JaxTestCase):
     self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
     self.assertEqual(pytree[3], 4)
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # Weakref destruction seems unpredictable with threads
   def test_devicearray_weakref_friendly(self):
     x = device_put(1.)
     y = weakref.ref(x)
@@ -2738,7 +2740,7 @@ class APITest(jtu.JaxTestCase):
 
     self.assertEqual(count(), 1)
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # jit cache misses aren't thread safe
   def test_jit_infer_params_cache(self):
     def f(x):
       return x
@@ -3329,7 +3331,7 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"):
       jax.grad(lambda x: x)(x)
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # logging isn't thread-safe
   def test_jit_compilation_time_logging(self):
     @api.jit
     def f(x):
@@ -3418,7 +3420,7 @@ class APITest(jtu.JaxTestCase):
     self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
     self.assertEqual(z2, 1)
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # monkey-patching mlir.jaxpr_subcomp isn't thread-safe
   def test_nested_jit_hoisting(self):
     @api.jit
     def f(x, y):
@@ -3456,7 +3458,7 @@ class APITest(jtu.JaxTestCase):
     self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
     self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # count_primitive_compiles isn't thread-safe
   def test_primitive_compilation_cache(self):
     with jtu.count_primitive_compiles() as count:
       lax.add(1, 2)
@@ -4016,7 +4018,7 @@ class APITest(jtu.JaxTestCase):
     a2 = jnp.array(((x, x), [x, x]))
     self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # count_jit_tracing_cache_miss() isn't thread-safe
   def test_eval_shape_weak_type(self):
     # https://github.com/jax-ml/jax/issues/23302
     arr = jax.numpy.array(1)
@@ -4145,7 +4147,7 @@ class APITest(jtu.JaxTestCase):
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
     self.assertIn('Precision.HIGH', str(jaxpr))
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # Updating global configs is not thread-safe.
   def test_dot_precision_forces_retrace(self):
     num_traces = 0
 
@@ -4318,7 +4320,7 @@ class APITest(jtu.JaxTestCase):
       api.make_jaxpr(lambda: jnp.array(3))()
     self.assertEqual(count(), 0)
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # Updating global configs is not thread-safe.
   def test_rank_promotion_forces_retrace(self):
     num_traces = 0
 
@@ -4459,7 +4461,7 @@ class APITest(jtu.JaxTestCase):
     self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})")
     self.assertEqual(jfoo.__module__, "jax")
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # Concurrent cache eviction means we may retrace
   def test_inner_jit_function_retracing(self):
     # https://github.com/jax-ml/jax/issues/7155
     inner_count = outer_count = 0
@@ -4507,6 +4509,7 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"):
       jax.device_put(jnp.arange(8), 'cpu')
 
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
   def test_clear_cache(self):
     @jax.jit
     def add(x):
@@ -4525,6 +4528,7 @@ class APITest(jtu.JaxTestCase):
           tracing_add_count += 1
       self.assertEqual(tracing_add_count, 2)
 
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
   def test_cache_miss_explanations(self):
     @jax.jit
     def f(x, y):
@@ -4584,6 +4588,7 @@ class APITest(jtu.JaxTestCase):
     msg = cm.output[0]
     self.assertIn("tracing context doesn't match", msg)
 
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
   def test_cache_miss_explanations_new_function_in_loop(self):
     @jax.jit
     def f(x, y):
@@ -4605,6 +4610,7 @@ class APITest(jtu.JaxTestCase):
       _, msg = cm.output
       self.assertIn('another function defined on the same line', msg)
 
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
   def test_cache_miss_explanations_unpacks_transforms(self):
     # Tests that the explain_tracing_cache_miss() function does not throw an
     # error when unpacking `transforms` with a length greater than 3.
@@ -4701,7 +4707,7 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, "ndim of its first argument"):
       jax.sharding.Mesh(jax.devices(), ("x", "y"))
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # weakref gc doesn't seem predictable
   def test_jit_boundmethod_reference_cycle(self):
     class A:
       def __init__(self):
@@ -4840,7 +4846,7 @@ class RematTest(jtu.JaxTestCase):
           ('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
           ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
       ])
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # monkey patches sin_p and cos_p
   def test_remat_basic(self, remat):
     @remat
     def g(x):
@@ -5178,7 +5184,7 @@ class RematTest(jtu.JaxTestCase):
           ('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
           ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
       ])
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # monkey patches sin_p
   def test_remat_no_redundant_flops(self, remat):
     # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584
 
@@ -6422,7 +6428,7 @@ class RematTest(jtu.JaxTestCase):
     self.assertIn(' sin ', str(jaxpr))
     self.assertIn(' cos ', str(jaxpr))
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # logging isn't thread-safe
   def test_remat_residual_logging(self):
     def f(x):
       x = jnp.sin(x)
@@ -11126,7 +11132,7 @@ class AutodidaxTest(jtu.JaxTestCase):
 
 class GarbageCollectionTest(jtu.JaxTestCase):
 
-  @jtu.thread_hostile_test()
+  @jtu.thread_unsafe_test()  # GC isn't predictable
   def test_xla_gc_callback(self):
     # https://github.com/jax-ml/jax/issues/14882
     x_np = np.arange(10, dtype='int32')
diff --git a/tests/array_test.py b/tests/array_test.py
index afcdad376..a620ed55a 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -890,6 +890,7 @@ class ShardingTest(jtu.JaxTestCase):
         r"factors: \[4, 2\] should evenly divide the shape\)"):
       mps.shard_shape((8, 3))
 
+  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
   def test_pmap_sharding_hash_eq(self):
     if jax.device_count() < 2:
       self.skipTest('Test needs >= 2 devices.')
diff --git a/tests/batching_test.py b/tests/batching_test.py
index 608053c23..bab18ce53 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -1326,6 +1326,7 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
   return lst
 
 
+@jtu.thread_unsafe_test_class()  # temporary registration isn't thread-safe
 class VmappableTest(jtu.JaxTestCase):
   def test_basic(self):
     with temporarily_register_named_array_vmappable():
diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py
index 27ebab887..ef245bc8d 100644
--- a/tests/compilation_cache_test.py
+++ b/tests/compilation_cache_test.py
@@ -95,6 +95,7 @@ def clear_cache() -> None:
     cc._cache.clear()
 
 
+@jtu.thread_unsafe_test_class()  # mocking isn't thread-safe
 class CompilationCacheTestCase(jtu.JaxTestCase):
 
   def setUp(self):
diff --git a/tests/core_test.py b/tests/core_test.py
index 7ca941c69..e4cbfd562 100644
--- a/tests/core_test.py
+++ b/tests/core_test.py
@@ -293,6 +293,7 @@ class CoreTest(jtu.JaxTestCase):
     assert d2_sin(0.0) == 0.0
     assert d3_sin(0.0) == -1.0
 
+  @jtu.thread_unsafe_test()  # gc isn't predictable when threaded
   def test_reference_cycles(self):
     gc.collect()
 
@@ -310,6 +311,7 @@ class CoreTest(jtu.JaxTestCase):
     finally:
       gc.set_debug(debug)
 
+  @jtu.thread_unsafe_test()  # gc isn't predictable when threaded
   def test_reference_cycles_jit(self):
     gc.collect()
 
diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index 0fc9665ce..392e544d8 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -59,6 +59,7 @@ class DebugCallbackTest(jtu.JaxTestCase):
       jax.debug.callback("this is not debug.print!")
 
 
+@jtu.thread_unsafe_test_class()  # printing isn't thread-safe
 class DebugPrintTest(jtu.JaxTestCase):
 
   def tearDown(self):
@@ -236,6 +237,7 @@ class DebugPrintTest(jtu.JaxTestCase):
     self.assertEqual(output(), "[1.23 2.35 0.  ]\n")
 
 
+@jtu.thread_unsafe_test_class()  # printing isn't thread-safe
 class DebugPrintTransformationTest(jtu.JaxTestCase):
 
   def test_debug_print_batching(self):
@@ -507,6 +509,7 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
       jax.effects_barrier()
     self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
 
+@jtu.thread_unsafe_test_class()  # printing isn't thread-safe
 class DebugPrintControlFlowTest(jtu.JaxTestCase):
 
   def _assertLinesEqual(self, text1, text2):
@@ -722,6 +725,7 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
       b3: 2
       """))
 
+@jtu.thread_unsafe_test_class()  # printing isn't thread-safe
 class DebugPrintParallelTest(jtu.JaxTestCase):
 
   def _assertLinesEqual(self, text1, text2):
diff --git a/tests/garbage_collection_guard_test.py b/tests/garbage_collection_guard_test.py
index 64b2baeff..5c34c6de2 100644
--- a/tests/garbage_collection_guard_test.py
+++ b/tests/garbage_collection_guard_test.py
@@ -34,6 +34,7 @@ def _create_array_cycle():
   return weakref.ref(n1)
 
 
+@jtu.thread_unsafe_test_class()  # GC isn't predictable when threaded.
 class GarbageCollectionGuardTest(jtu.JaxTestCase):
 
   def test_gced_array_is_not_logged_by_default(self):
diff --git a/tests/infeed_test.py b/tests/infeed_test.py
index 5dd52b416..060502ae6 100644
--- a/tests/infeed_test.py
+++ b/tests/infeed_test.py
@@ -28,6 +28,7 @@ import numpy as np
 jax.config.parse_flags_with_absl()
 
 
+@jtu.thread_unsafe_test_class()  # infeed isn't thread-safe
 class InfeedTest(jtu.JaxTestCase):
 
   def setUp(self):
diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py
index 922b37ffa..2d63a4834 100644
--- a/tests/jaxpr_effects_test.py
+++ b/tests/jaxpr_effects_test.py
@@ -269,6 +269,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
     self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
 
 
+@jtu.thread_unsafe_test_class()  # because of mlir.register_lowering calls
 class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
 
   def setUp(self):
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index 68c5d45c9..a04892816 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -2138,6 +2138,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     expected = jnp.array([])
     self.assertAllClose(ans, expected)
 
+  @jtu.thread_unsafe_test()  # Cache eviction means we might retrace
   def testCaching(self):
     def cond(x):
       assert python_should_be_executing
@@ -3033,6 +3034,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     self.assertEqual(y, x)
     self.assertIsInstance(y, jax.Array)
 
+  @jtu.thread_unsafe_test()  # live_arrays count isn't thread-safe
   def test_cond_memory_leak(self):
     # https://github.com/jax-ml/jax/issues/12719
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 9db2f5bcc..a2b3e8b62 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -3886,6 +3886,9 @@ def bake_vmap(batched_args, batch_dims):
   return ys, bdim_out
 
 
+# All tests in this test class are thread-hostile because they add and remove
+# primitives from global maps.
+@jtu.thread_unsafe_test_class()  # registration isn't thread-safe
 class CustomElementTypesTest(jtu.JaxTestCase):
 
   def setUp(self):
diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py
index 02d340abc..fc2b0df84 100644
--- a/tests/lobpcg_test.py
+++ b/tests/lobpcg_test.py
@@ -197,6 +197,7 @@ def _callable_generators(dtype):
     jax_debug_nans=True,
     jax_numpy_rank_promotion='raise',
     jax_traceback_filtering='off')
+@jtu.thread_unsafe_test_class()  # matplotlib isn't thread-safe
 class LobpcgTest(jtu.JaxTestCase):
 
   def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):
diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py
index 1fc6fe1e9..38a37844e 100644
--- a/tests/multi_device_test.py
+++ b/tests/multi_device_test.py
@@ -123,6 +123,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
     val = jax.random.normal(rng, ())
     self.assert_committed_to_device(val, device)
 
+  @jtu.thread_unsafe_test()  # count_primitive_compiles isn't thread-safe
   def test_primitive_compilation_cache(self):
     devices = self.get_devices()
 
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 8c737b886..17cba7faf 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -2144,6 +2144,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile()
       compiled(a1)  # no error
 
+  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
   def test_pjit_single_device_sharding_add(self):
     a = np.array([1, 2, 3], dtype=jnp.float32)
     b = np.array([4, 5, 6], dtype=jnp.float32)
@@ -2399,6 +2400,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(out1_sharding_id, out3_sharding_id)
     self.assertEqual(out2_sharding_id, out3_sharding_id)
 
+  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
   def test_out_sharding_indices_id_cache_hit(self):
     shape = (8, 2)
     mesh = jtu.create_mesh((4, 2), ('x', 'y'))
@@ -2863,6 +2865,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     x_is_tracer, y_is_tracer = False, True
     assert f_mixed(x=2, y=3) == 1
 
+  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
   def test_pjit_kwargs(self):
     a = jnp.arange(8.)
     b = jnp.arange(4.)
@@ -3507,6 +3510,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
     self.assertEqual(cache_info4.misses, cache_info3.misses)
 
+  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
   def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
     mesh = jtu.create_mesh((2, 1), ('x', 'y'))
     ns = NamedSharding(mesh, P('x'))
@@ -3600,6 +3604,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       self.assertIsInstance(out3.sharding, NamedSharding)
     self.assertEqual(count(), 1)
 
+  @jtu.thread_unsafe_test()  # cache_info isn't thread-safe
   def test_jit_mul_sum_sharding_preserved(self):
     if config.use_shardy_partitioner.value:
       raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 4694c8155..7ca9a43b9 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 from concurrent.futures import ThreadPoolExecutor
+import contextlib
 from functools import partial
 import itertools as it
 import gc
@@ -3195,19 +3196,11 @@ class EagerPmapMixin:
 
   def setUp(self):
     super().setUp()
-    self.eager_pmap_enabled = config.eager_pmap.value
-    self.jit_disabled = config.disable_jit.value
-    config.update('jax_disable_jit', True)
-    config.update('jax_eager_pmap', True)
-    self.warning_ctx = jtu.ignore_warning(
-        message="Some donated buffers were not usable", category=UserWarning)
-    self.warning_ctx.__enter__()
-
-  def tearDown(self):
-    self.warning_ctx.__exit__(None, None, None)
-    config.update('jax_eager_pmap', self.eager_pmap_enabled)
-    config.update('jax_disable_jit', self.jit_disabled)
-    super().tearDown()
+    stack = contextlib.ExitStack()
+    stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
+    stack.enter_context(jtu.ignore_warning(
+        message="Some donated buffers were not usable", category=UserWarning))
+    self.addCleanup(stack.close)
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PythonPmapEagerTest(EagerPmapMixin, PythonPmapTest):
diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index f0909094a..b686d30ad 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -53,6 +53,7 @@ except ImportError:
 jax.config.parse_flags_with_absl()
 
 
+@jtu.thread_unsafe_test_class()  # profiler isn't thread-safe
 class ProfilerTest(unittest.TestCase):
   # These tests simply test that the profiler API does not crash; they do not
   # check functional correctness.
diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py
index 5a3b9bab5..9b937bd67 100644
--- a/tests/python_callback_test.py
+++ b/tests/python_callback_test.py
@@ -548,6 +548,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
         np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
 
   @with_pure_and_io_callbacks
+  @jtu.thread_unsafe_test()  # logging isn't thread-safe
   def test_exception_in_callback(self, *, callback):
     def fail(x):
       raise RuntimeError("Ooops")
@@ -570,6 +571,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
       self.assertIn("Traceback (most recent call last)", output)
 
   @with_pure_and_io_callbacks
+  @jtu.thread_unsafe_test()  # count_primitive_compiles isn't thread-safe
   def test_compilation_caching(self, *, callback):
     def f_outside(x):
       return 2 * x
diff --git a/tests/state_test.py b/tests/state_test.py
index cf204b667..6c9a5b127 100644
--- a/tests/state_test.py
+++ b/tests/state_test.py
@@ -955,6 +955,7 @@ if CAN_USE_HYPOTHESIS:
     assert next(idx_, None) is None
     return idx
 
+  @jtu.thread_unsafe_test_class()  # hypothesis isn't thread-safe
   class StateHypothesisTest(jtu.JaxTestCase):
 
     @hp.given(get_vmap_params())
@@ -1711,6 +1712,7 @@ if CAN_USE_HYPOTHESIS:
                     min_dim=max(f1.min_dim, f2.min_dim),
                     max_dim=min(f1.max_dim, f2.max_dim))
 
+  @jtu.thread_unsafe_test_class()  # because of hypothesis
   class RunStateHypothesisTest(jtu.JaxTestCase):
 
     @jax.legacy_prng_key('allow')