mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile. PiperOrigin-RevId: 714037277
This commit is contained in:
parent
86643a1b3e
commit
c61b2f6b81
@ -1013,16 +1013,27 @@ if hasattr(util, 'Mutex'):
|
||||
_test_rwlock = util.Mutex()
|
||||
|
||||
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
|
||||
_test_rwlock.reader_lock()
|
||||
try:
|
||||
test(result) # type: ignore
|
||||
finally:
|
||||
_test_rwlock.reader_unlock()
|
||||
if getattr(test.__class__, "thread_hostile", False):
|
||||
_test_rwlock.writer_lock()
|
||||
try:
|
||||
test(result) # type: ignore
|
||||
finally:
|
||||
_test_rwlock.writer_unlock()
|
||||
else:
|
||||
_test_rwlock.reader_lock()
|
||||
try:
|
||||
test(result) # type: ignore
|
||||
finally:
|
||||
_test_rwlock.reader_unlock()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def thread_hostile_test():
|
||||
"Decorator for tests that are not thread-safe."
|
||||
def thread_unsafe_test():
|
||||
"""Decorator for tests that are not thread-safe.
|
||||
|
||||
Note: this decorator (naturally) only applies to what it wraps, not to, say,
|
||||
code in separate setUp() or tearDown() methods.
|
||||
"""
|
||||
if TEST_NUM_THREADS.value <= 0:
|
||||
yield
|
||||
return
|
||||
@ -1048,9 +1059,19 @@ else:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def thread_hostile_test():
|
||||
def thread_unsafe_test():
|
||||
yield # No reader-writer lock, so we get no parallelism.
|
||||
|
||||
|
||||
def thread_unsafe_test_class():
|
||||
"Decorator that marks a TestCase class as thread-hostile."
|
||||
def f(klass):
|
||||
assert issubclass(klass, unittest.TestCase), type(klass)
|
||||
klass.thread_hostile = True
|
||||
return klass
|
||||
return f
|
||||
|
||||
|
||||
class ThreadSafeTestResult:
|
||||
"""
|
||||
Wraps a TestResult to make it thread safe.
|
||||
@ -1074,8 +1095,9 @@ class ThreadSafeTestResult:
|
||||
def stopTest(self, test: unittest.TestCase):
|
||||
stop_time = time.time()
|
||||
with self.lock:
|
||||
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
|
||||
# override how it gets the time.
|
||||
# If test_result is an ABSL _TextAndXMLTestResult we override how it gets
|
||||
# the time. This affects the timing that shows up in the XML output
|
||||
# consumed by CI.
|
||||
time_getter = getattr(self.test_result, "time_getter", None)
|
||||
try:
|
||||
self.test_result.time_getter = lambda: self.start_time
|
||||
|
@ -632,7 +632,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
python_should_be_executing = False
|
||||
jit(f)(3)
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # GC effects aren't predictable with threads
|
||||
def test_jit_cache_clear(self):
|
||||
@jit
|
||||
def f(x, y):
|
||||
@ -1605,6 +1605,7 @@ class APITest(jtu.JaxTestCase):
|
||||
assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
|
||||
assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))
|
||||
|
||||
@jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace.
|
||||
def test_grad_of_jit(self):
|
||||
side = []
|
||||
|
||||
@ -1618,6 +1619,7 @@ class APITest(jtu.JaxTestCase):
|
||||
assert grad(f)(2.0) == 4.0
|
||||
assert len(side) == 1
|
||||
|
||||
@jtu.thread_unsafe_test() # Concurrent ache eviction means we may retrace.
|
||||
def test_jit_of_grad(self):
|
||||
side = []
|
||||
|
||||
@ -2589,7 +2591,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
|
||||
self.assertEqual(pytree[3], 4)
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads
|
||||
def test_devicearray_weakref_friendly(self):
|
||||
x = device_put(1.)
|
||||
y = weakref.ref(x)
|
||||
@ -2738,7 +2740,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # jit cache misses aren't thread safe
|
||||
def test_jit_infer_params_cache(self):
|
||||
def f(x):
|
||||
return x
|
||||
@ -3329,7 +3331,7 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"):
|
||||
jax.grad(lambda x: x)(x)
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # logging isn't thread-safe
|
||||
def test_jit_compilation_time_logging(self):
|
||||
@api.jit
|
||||
def f(x):
|
||||
@ -3418,7 +3420,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
|
||||
self.assertEqual(z2, 1)
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # monkey-patching mlir.jaxpr_subcomp isn't thread-safe
|
||||
def test_nested_jit_hoisting(self):
|
||||
@api.jit
|
||||
def f(x, y):
|
||||
@ -3456,7 +3458,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
|
||||
self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
|
||||
def test_primitive_compilation_cache(self):
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
lax.add(1, 2)
|
||||
@ -4016,7 +4018,7 @@ class APITest(jtu.JaxTestCase):
|
||||
a2 = jnp.array(((x, x), [x, x]))
|
||||
self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe
|
||||
def test_eval_shape_weak_type(self):
|
||||
# https://github.com/jax-ml/jax/issues/23302
|
||||
arr = jax.numpy.array(1)
|
||||
@ -4145,7 +4147,7 @@ class APITest(jtu.JaxTestCase):
|
||||
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
|
||||
self.assertIn('Precision.HIGH', str(jaxpr))
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # Updating global configs is not thread-safe.
|
||||
def test_dot_precision_forces_retrace(self):
|
||||
num_traces = 0
|
||||
|
||||
@ -4318,7 +4320,7 @@ class APITest(jtu.JaxTestCase):
|
||||
api.make_jaxpr(lambda: jnp.array(3))()
|
||||
self.assertEqual(count(), 0)
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # Updating global configs is not thread-safe.
|
||||
def test_rank_promotion_forces_retrace(self):
|
||||
num_traces = 0
|
||||
|
||||
@ -4459,7 +4461,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})")
|
||||
self.assertEqual(jfoo.__module__, "jax")
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace
|
||||
def test_inner_jit_function_retracing(self):
|
||||
# https://github.com/jax-ml/jax/issues/7155
|
||||
inner_count = outer_count = 0
|
||||
@ -4507,6 +4509,7 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"):
|
||||
jax.device_put(jnp.arange(8), 'cpu')
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_clear_cache(self):
|
||||
@jax.jit
|
||||
def add(x):
|
||||
@ -4525,6 +4528,7 @@ class APITest(jtu.JaxTestCase):
|
||||
tracing_add_count += 1
|
||||
self.assertEqual(tracing_add_count, 2)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations(self):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
@ -4584,6 +4588,7 @@ class APITest(jtu.JaxTestCase):
|
||||
msg = cm.output[0]
|
||||
self.assertIn("tracing context doesn't match", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_new_function_in_loop(self):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
@ -4605,6 +4610,7 @@ class APITest(jtu.JaxTestCase):
|
||||
_, msg = cm.output
|
||||
self.assertIn('another function defined on the same line', msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_unpacks_transforms(self):
|
||||
# Tests that the explain_tracing_cache_miss() function does not throw an
|
||||
# error when unpacking `transforms` with a length greater than 3.
|
||||
@ -4701,7 +4707,7 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "ndim of its first argument"):
|
||||
jax.sharding.Mesh(jax.devices(), ("x", "y"))
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # weakref gc doesn't seem predictable
|
||||
def test_jit_boundmethod_reference_cycle(self):
|
||||
class A:
|
||||
def __init__(self):
|
||||
@ -4840,7 +4846,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # monkey patches sin_p and cos_p
|
||||
def test_remat_basic(self, remat):
|
||||
@remat
|
||||
def g(x):
|
||||
@ -5178,7 +5184,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
|
||||
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
|
||||
])
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # monkey patches sin_p
|
||||
def test_remat_no_redundant_flops(self, remat):
|
||||
# see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584
|
||||
|
||||
@ -6422,7 +6428,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # logging isn't thread-safe
|
||||
def test_remat_residual_logging(self):
|
||||
def f(x):
|
||||
x = jnp.sin(x)
|
||||
@ -11126,7 +11132,7 @@ class AutodidaxTest(jtu.JaxTestCase):
|
||||
|
||||
class GarbageCollectionTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.thread_hostile_test()
|
||||
@jtu.thread_unsafe_test() # GC isn't predictable
|
||||
def test_xla_gc_callback(self):
|
||||
# https://github.com/jax-ml/jax/issues/14882
|
||||
x_np = np.arange(10, dtype='int32')
|
||||
|
@ -890,6 +890,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
r"factors: \[4, 2\] should evenly divide the shape\)"):
|
||||
mps.shard_shape((8, 3))
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_pmap_sharding_hash_eq(self):
|
||||
if jax.device_count() < 2:
|
||||
self.skipTest('Test needs >= 2 devices.')
|
||||
|
@ -1326,6 +1326,7 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]:
|
||||
return lst
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe
|
||||
class VmappableTest(jtu.JaxTestCase):
|
||||
def test_basic(self):
|
||||
with temporarily_register_named_array_vmappable():
|
||||
|
@ -95,6 +95,7 @@ def clear_cache() -> None:
|
||||
cc._cache.clear()
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # mocking isn't thread-safe
|
||||
class CompilationCacheTestCase(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -293,6 +293,7 @@ class CoreTest(jtu.JaxTestCase):
|
||||
assert d2_sin(0.0) == 0.0
|
||||
assert d3_sin(0.0) == -1.0
|
||||
|
||||
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
|
||||
def test_reference_cycles(self):
|
||||
gc.collect()
|
||||
|
||||
@ -310,6 +311,7 @@ class CoreTest(jtu.JaxTestCase):
|
||||
finally:
|
||||
gc.set_debug(debug)
|
||||
|
||||
@jtu.thread_unsafe_test() # gc isn't predictable when threaded
|
||||
def test_reference_cycles_jit(self):
|
||||
gc.collect()
|
||||
|
||||
|
@ -59,6 +59,7 @@ class DebugCallbackTest(jtu.JaxTestCase):
|
||||
jax.debug.callback("this is not debug.print!")
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
|
||||
class DebugPrintTest(jtu.JaxTestCase):
|
||||
|
||||
def tearDown(self):
|
||||
@ -236,6 +237,7 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "[1.23 2.35 0. ]\n")
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
|
||||
class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
|
||||
def test_debug_print_batching(self):
|
||||
@ -507,6 +509,7 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
|
||||
|
||||
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
|
||||
class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def _assertLinesEqual(self, text1, text2):
|
||||
@ -722,6 +725,7 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
b3: 2
|
||||
"""))
|
||||
|
||||
@jtu.thread_unsafe_test_class() # printing isn't thread-safe
|
||||
class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
|
||||
def _assertLinesEqual(self, text1, text2):
|
||||
|
@ -34,6 +34,7 @@ def _create_array_cycle():
|
||||
return weakref.ref(n1)
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # GC isn't predictable when threaded.
|
||||
class GarbageCollectionGuardTest(jtu.JaxTestCase):
|
||||
|
||||
def test_gced_array_is_not_logged_by_default(self):
|
||||
|
@ -28,6 +28,7 @@ import numpy as np
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # infeed isn't thread-safe
|
||||
class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -269,6 +269,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
|
||||
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls
|
||||
class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -2138,6 +2138,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = jnp.array([])
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
@jtu.thread_unsafe_test() # Cache eviction means we might retrace
|
||||
def testCaching(self):
|
||||
def cond(x):
|
||||
assert python_should_be_executing
|
||||
@ -3033,6 +3034,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertEqual(y, x)
|
||||
self.assertIsInstance(y, jax.Array)
|
||||
|
||||
@jtu.thread_unsafe_test() # live_arrays count isn't thread-safe
|
||||
def test_cond_memory_leak(self):
|
||||
# https://github.com/jax-ml/jax/issues/12719
|
||||
|
||||
|
@ -3886,6 +3886,9 @@ def bake_vmap(batched_args, batch_dims):
|
||||
return ys, bdim_out
|
||||
|
||||
|
||||
# All tests in this test class are thread-hostile because they add and remove
|
||||
# primitives from global maps.
|
||||
@jtu.thread_unsafe_test_class() # registration isn't thread-safe
|
||||
class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -197,6 +197,7 @@ def _callable_generators(dtype):
|
||||
jax_debug_nans=True,
|
||||
jax_numpy_rank_promotion='raise',
|
||||
jax_traceback_filtering='off')
|
||||
@jtu.thread_unsafe_test_class() # matplotlib isn't thread-safe
|
||||
class LobpcgTest(jtu.JaxTestCase):
|
||||
|
||||
def checkLobpcgConsistency(self, matrix_name, n, k, m, tol, dtype):
|
||||
|
@ -123,6 +123,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
|
||||
val = jax.random.normal(rng, ())
|
||||
self.assert_committed_to_device(val, device)
|
||||
|
||||
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
|
||||
def test_primitive_compilation_cache(self):
|
||||
devices = self.get_devices()
|
||||
|
||||
|
@ -2144,6 +2144,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
compiled = f.lower(core.ShapedArray(input_shape, jnp.float32)).compile()
|
||||
compiled(a1) # no error
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_pjit_single_device_sharding_add(self):
|
||||
a = np.array([1, 2, 3], dtype=jnp.float32)
|
||||
b = np.array([4, 5, 6], dtype=jnp.float32)
|
||||
@ -2399,6 +2400,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out1_sharding_id, out3_sharding_id)
|
||||
self.assertEqual(out2_sharding_id, out3_sharding_id)
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_out_sharding_indices_id_cache_hit(self):
|
||||
shape = (8, 2)
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
@ -2863,6 +2865,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
x_is_tracer, y_is_tracer = False, True
|
||||
assert f_mixed(x=2, y=3) == 1
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_pjit_kwargs(self):
|
||||
a = jnp.arange(8.)
|
||||
b = jnp.arange(4.)
|
||||
@ -3507,6 +3510,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
|
||||
self.assertEqual(cache_info4.misses, cache_info3.misses)
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
ns = NamedSharding(mesh, P('x'))
|
||||
@ -3600,6 +3604,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out3.sharding, NamedSharding)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_jit_mul_sum_sharding_preserved(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
|
||||
|
@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import gc
|
||||
@ -3195,19 +3196,11 @@ class EagerPmapMixin:
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.eager_pmap_enabled = config.eager_pmap.value
|
||||
self.jit_disabled = config.disable_jit.value
|
||||
config.update('jax_disable_jit', True)
|
||||
config.update('jax_eager_pmap', True)
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="Some donated buffers were not usable", category=UserWarning)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
config.update('jax_eager_pmap', self.eager_pmap_enabled)
|
||||
config.update('jax_disable_jit', self.jit_disabled)
|
||||
super().tearDown()
|
||||
stack = contextlib.ExitStack()
|
||||
stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
|
||||
stack.enter_context(jtu.ignore_warning(
|
||||
message="Some donated buffers were not usable", category=UserWarning))
|
||||
self.addCleanup(stack.close)
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PythonPmapEagerTest(EagerPmapMixin, PythonPmapTest):
|
||||
|
@ -53,6 +53,7 @@ except ImportError:
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.thread_unsafe_test_class() # profiler isn't thread-safe
|
||||
class ProfilerTest(unittest.TestCase):
|
||||
# These tests simply test that the profiler API does not crash; they do not
|
||||
# check functional correctness.
|
||||
|
@ -548,6 +548,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
|
||||
|
||||
@with_pure_and_io_callbacks
|
||||
@jtu.thread_unsafe_test() # logging isn't thread-safe
|
||||
def test_exception_in_callback(self, *, callback):
|
||||
def fail(x):
|
||||
raise RuntimeError("Ooops")
|
||||
@ -570,6 +571,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
self.assertIn("Traceback (most recent call last)", output)
|
||||
|
||||
@with_pure_and_io_callbacks
|
||||
@jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe
|
||||
def test_compilation_caching(self, *, callback):
|
||||
def f_outside(x):
|
||||
return 2 * x
|
||||
|
@ -955,6 +955,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
assert next(idx_, None) is None
|
||||
return idx
|
||||
|
||||
@jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe
|
||||
class StateHypothesisTest(jtu.JaxTestCase):
|
||||
|
||||
@hp.given(get_vmap_params())
|
||||
@ -1711,6 +1712,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
min_dim=max(f1.min_dim, f2.min_dim),
|
||||
max_dim=min(f1.max_dim, f2.max_dim))
|
||||
|
||||
@jtu.thread_unsafe_test_class() # because of hypothesis
|
||||
class RunStateHypothesisTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.legacy_prng_key('allow')
|
||||
|
Loading…
x
Reference in New Issue
Block a user