From 8f2f4b45fb07b2e1ad74909573142b8119471fa6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 10 Jan 2025 11:24:10 -0800 Subject: [PATCH] Annotate several tests as thread-unsafe. PiperOrigin-RevId: 714117130 --- tests/api_test.py | 1 + tests/array_test.py | 1 + tests/cache_key_test.py | 3 +++ tests/debugging_primitives_test.py | 1 + tests/mock_gpu_test.py | 1 + tests/mock_gpu_topology_test.py | 1 + tests/pjit_test.py | 5 +++-- tests/shard_map_test.py | 1 + 8 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/api_test.py b/tests/api_test.py index 379c63900..b7e9b4fb9 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -11104,6 +11104,7 @@ class CleanupTest(jtu.JaxTestCase): class EnvironmentInfoTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) + @jtu.thread_unsafe_test() def test_print_environment_info(self, return_string): # Flush stdout buffer before checking. sys.stdout.flush() diff --git a/tests/array_test.py b/tests/array_test.py index a620ed55a..bf7aa6488 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -838,6 +838,7 @@ class ShardingTest(jtu.JaxTestCase): self.assertListEqual(hlo_sharding.tile_assignment_devices(), [0, 2, 4, 6, 1, 3, 5, 7]) + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_util_clear_cache(self): mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P()) diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index f84a9d5fb..74f76c75b 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -271,6 +271,7 @@ class CacheKeyTest(jtu.JaxTestCase): self.assertNotEqual(hash_1, hash_2) @parameterized.parameters([False, True]) + @jtu.thread_unsafe_test() # env vars are not thread-safe def test_identical_computations_different_metadata(self, include_metadata): f = lambda x, y: lax.mul(lax.add(x, y), 2) g = lambda x, y: lax.mul(lax.add(x, y), 2) @@ -287,6 +288,7 @@ class CacheKeyTest(jtu.JaxTestCase): key2 = cache_key.get(computation2, devices, compile_options, backend) self.assertEqual(include_metadata, key1 != key2) + @jtu.thread_unsafe_test() # env vars are not thread-safe def test_xla_flags(self): if jtu.is_device_tpu(version=4): raise unittest.SkipTest("TODO(b/240151176)") @@ -333,6 +335,7 @@ class CacheKeyTest(jtu.JaxTestCase): del os.environ["XLA_FLAGS"] sys.argv = orig_argv + @jtu.thread_unsafe_test() # env vars are not thread-safe def test_libtpu_init_args(self): if jtu.is_device_tpu(version=4): raise unittest.SkipTest("TODO(b/240151176)") diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 392e544d8..019d68cff 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -899,6 +899,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase): f(jnp.arange(2)) jax.effects_barrier() +@jtu.thread_unsafe_test_class() # logging isn't thread-safe class VisualizeShardingTest(jtu.JaxTestCase): def _create_devices(self, shape): diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index d17b3c1e7..44ca85782 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -29,6 +29,7 @@ NUM_SHARDS = 4 @jtu.with_global_config(mock_num_gpu_processes=NUM_SHARDS) +@jtu.thread_unsafe_test_class() class MockGPUTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py index 3c479025d..59c511ae6 100644 --- a/tests/mock_gpu_topology_test.py +++ b/tests/mock_gpu_topology_test.py @@ -28,6 +28,7 @@ NUM_HOSTS_PER_SLICE = 4 @jtu.with_global_config( jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1", jax_cuda_visible_devices="0") +@jtu.thread_unsafe_test_class() class MockGPUTopologyTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 17cba7faf..44768cbed 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1198,7 +1198,7 @@ class PJitTest(jtu.BufferDonationTestCase): x = jnp.array([4.2], dtype=jnp.float32) jaxpr = jax.make_jaxpr(g)(x) self.assertEqual( - jaxpr.pretty_print(), + jaxpr.pretty_print(use_color=False), textwrap.dedent(""" let lambda = { lambda ; a:f32[1]. let b:f32[1] = integer_pow[y=2] a in (b,) } in { lambda ; c:f32[1]. let @@ -1225,7 +1225,7 @@ class PJitTest(jtu.BufferDonationTestCase): x = jnp.array([4.2], dtype=jnp.float32) jaxpr = jax.make_jaxpr(g)(x, x) self.assertEqual( - jaxpr.pretty_print(), + jaxpr.pretty_print(use_color=False), textwrap.dedent(""" let f = { lambda ; a:f32[1] b:f32[1]. let c:f32[1] = mul b a in (c,) } in { lambda ; d:f32[1] e:f32[1]. let @@ -3462,6 +3462,7 @@ class ArrayPjitTest(jtu.JaxTestCase): # Test second order autodiff with src argument specified in device_put. jtu.check_grads(g, (arr,), order=2) + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pjit_out_sharding_preserved(self): if config.use_shardy_partitioner.value: raise unittest.SkipTest("Shardy doesn't support PositionalSharding") diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index c1c5e1556..b23c17022 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -859,6 +859,7 @@ class ShardMapTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) @jtu.run_on_devices('cpu', 'gpu', 'tpu') + @jtu.thread_unsafe_test() def test_debug_print_jit(self, jit): if config.use_shardy_partitioner.value: self.skipTest(