Annotate several tests as thread-unsafe.

PiperOrigin-RevId: 714117130
This commit is contained in:
Peter Hawkins 2025-01-10 11:24:10 -08:00 committed by jax authors
parent 016fca79ca
commit 8f2f4b45fb
8 changed files with 12 additions and 2 deletions

View File

@ -11104,6 +11104,7 @@ class CleanupTest(jtu.JaxTestCase):
class EnvironmentInfoTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
@jtu.thread_unsafe_test()
def test_print_environment_info(self, return_string):
# Flush stdout buffer before checking.
sys.stdout.flush()

View File

@ -838,6 +838,7 @@ class ShardingTest(jtu.JaxTestCase):
self.assertListEqual(hlo_sharding.tile_assignment_devices(),
[0, 2, 4, 6, 1, 3, 5, 7])
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_util_clear_cache(self):
mesh = jtu.create_mesh((1,), ('x',))
s = NamedSharding(mesh, P())

View File

@ -271,6 +271,7 @@ class CacheKeyTest(jtu.JaxTestCase):
self.assertNotEqual(hash_1, hash_2)
@parameterized.parameters([False, True])
@jtu.thread_unsafe_test() # env vars are not thread-safe
def test_identical_computations_different_metadata(self, include_metadata):
f = lambda x, y: lax.mul(lax.add(x, y), 2)
g = lambda x, y: lax.mul(lax.add(x, y), 2)
@ -287,6 +288,7 @@ class CacheKeyTest(jtu.JaxTestCase):
key2 = cache_key.get(computation2, devices, compile_options, backend)
self.assertEqual(include_metadata, key1 != key2)
@jtu.thread_unsafe_test() # env vars are not thread-safe
def test_xla_flags(self):
if jtu.is_device_tpu(version=4):
raise unittest.SkipTest("TODO(b/240151176)")
@ -333,6 +335,7 @@ class CacheKeyTest(jtu.JaxTestCase):
del os.environ["XLA_FLAGS"]
sys.argv = orig_argv
@jtu.thread_unsafe_test() # env vars are not thread-safe
def test_libtpu_init_args(self):
if jtu.is_device_tpu(version=4):
raise unittest.SkipTest("TODO(b/240151176)")

View File

@ -899,6 +899,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
f(jnp.arange(2))
jax.effects_barrier()
@jtu.thread_unsafe_test_class() # logging isn't thread-safe
class VisualizeShardingTest(jtu.JaxTestCase):
def _create_devices(self, shape):

View File

@ -29,6 +29,7 @@ NUM_SHARDS = 4
@jtu.with_global_config(mock_num_gpu_processes=NUM_SHARDS)
@jtu.thread_unsafe_test_class()
class MockGPUTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -28,6 +28,7 @@ NUM_HOSTS_PER_SLICE = 4
@jtu.with_global_config(
jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1",
jax_cuda_visible_devices="0")
@jtu.thread_unsafe_test_class()
class MockGPUTopologyTest(jtu.JaxTestCase):
def setUp(self):

View File

@ -1198,7 +1198,7 @@ class PJitTest(jtu.BufferDonationTestCase):
x = jnp.array([4.2], dtype=jnp.float32)
jaxpr = jax.make_jaxpr(g)(x)
self.assertEqual(
jaxpr.pretty_print(),
jaxpr.pretty_print(use_color=False),
textwrap.dedent("""
let lambda = { lambda ; a:f32[1]. let b:f32[1] = integer_pow[y=2] a in (b,) } in
{ lambda ; c:f32[1]. let
@ -1225,7 +1225,7 @@ class PJitTest(jtu.BufferDonationTestCase):
x = jnp.array([4.2], dtype=jnp.float32)
jaxpr = jax.make_jaxpr(g)(x, x)
self.assertEqual(
jaxpr.pretty_print(),
jaxpr.pretty_print(use_color=False),
textwrap.dedent("""
let f = { lambda ; a:f32[1] b:f32[1]. let c:f32[1] = mul b a in (c,) } in
{ lambda ; d:f32[1] e:f32[1]. let
@ -3462,6 +3462,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# Test second order autodiff with src argument specified in device_put.
jtu.check_grads(g, (arr,), order=2)
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
def test_pjit_out_sharding_preserved(self):
if config.use_shardy_partitioner.value:
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")

View File

@ -859,6 +859,7 @@ class ShardMapTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
@jtu.run_on_devices('cpu', 'gpu', 'tpu')
@jtu.thread_unsafe_test()
def test_debug_print_jit(self, jit):
if config.use_shardy_partitioner.value:
self.skipTest(