mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
This commit is contained in:
parent
016fca79ca
commit
8f2f4b45fb
@ -11104,6 +11104,7 @@ class CleanupTest(jtu.JaxTestCase):
|
||||
|
||||
class EnvironmentInfoTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters([True, False])
|
||||
@jtu.thread_unsafe_test()
|
||||
def test_print_environment_info(self, return_string):
|
||||
# Flush stdout buffer before checking.
|
||||
sys.stdout.flush()
|
||||
|
@ -838,6 +838,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
self.assertListEqual(hlo_sharding.tile_assignment_devices(),
|
||||
[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_util_clear_cache(self):
|
||||
mesh = jtu.create_mesh((1,), ('x',))
|
||||
s = NamedSharding(mesh, P())
|
||||
|
@ -271,6 +271,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(hash_1, hash_2)
|
||||
|
||||
@parameterized.parameters([False, True])
|
||||
@jtu.thread_unsafe_test() # env vars are not thread-safe
|
||||
def test_identical_computations_different_metadata(self, include_metadata):
|
||||
f = lambda x, y: lax.mul(lax.add(x, y), 2)
|
||||
g = lambda x, y: lax.mul(lax.add(x, y), 2)
|
||||
@ -287,6 +288,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
key2 = cache_key.get(computation2, devices, compile_options, backend)
|
||||
self.assertEqual(include_metadata, key1 != key2)
|
||||
|
||||
@jtu.thread_unsafe_test() # env vars are not thread-safe
|
||||
def test_xla_flags(self):
|
||||
if jtu.is_device_tpu(version=4):
|
||||
raise unittest.SkipTest("TODO(b/240151176)")
|
||||
@ -333,6 +335,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
del os.environ["XLA_FLAGS"]
|
||||
sys.argv = orig_argv
|
||||
|
||||
@jtu.thread_unsafe_test() # env vars are not thread-safe
|
||||
def test_libtpu_init_args(self):
|
||||
if jtu.is_device_tpu(version=4):
|
||||
raise unittest.SkipTest("TODO(b/240151176)")
|
||||
|
@ -899,6 +899,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
f(jnp.arange(2))
|
||||
jax.effects_barrier()
|
||||
|
||||
@jtu.thread_unsafe_test_class() # logging isn't thread-safe
|
||||
class VisualizeShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def _create_devices(self, shape):
|
||||
|
@ -29,6 +29,7 @@ NUM_SHARDS = 4
|
||||
|
||||
|
||||
@jtu.with_global_config(mock_num_gpu_processes=NUM_SHARDS)
|
||||
@jtu.thread_unsafe_test_class()
|
||||
class MockGPUTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -28,6 +28,7 @@ NUM_HOSTS_PER_SLICE = 4
|
||||
@jtu.with_global_config(
|
||||
jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1",
|
||||
jax_cuda_visible_devices="0")
|
||||
@jtu.thread_unsafe_test_class()
|
||||
class MockGPUTopologyTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -1198,7 +1198,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
x = jnp.array([4.2], dtype=jnp.float32)
|
||||
jaxpr = jax.make_jaxpr(g)(x)
|
||||
self.assertEqual(
|
||||
jaxpr.pretty_print(),
|
||||
jaxpr.pretty_print(use_color=False),
|
||||
textwrap.dedent("""
|
||||
let lambda = { lambda ; a:f32[1]. let b:f32[1] = integer_pow[y=2] a in (b,) } in
|
||||
{ lambda ; c:f32[1]. let
|
||||
@ -1225,7 +1225,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
x = jnp.array([4.2], dtype=jnp.float32)
|
||||
jaxpr = jax.make_jaxpr(g)(x, x)
|
||||
self.assertEqual(
|
||||
jaxpr.pretty_print(),
|
||||
jaxpr.pretty_print(use_color=False),
|
||||
textwrap.dedent("""
|
||||
let f = { lambda ; a:f32[1] b:f32[1]. let c:f32[1] = mul b a in (c,) } in
|
||||
{ lambda ; d:f32[1] e:f32[1]. let
|
||||
@ -3462,6 +3462,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# Test second order autodiff with src argument specified in device_put.
|
||||
jtu.check_grads(g, (arr,), order=2)
|
||||
|
||||
@jtu.thread_unsafe_test() # cache_info isn't thread-safe
|
||||
def test_pjit_out_sharding_preserved(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
raise unittest.SkipTest("Shardy doesn't support PositionalSharding")
|
||||
|
@ -859,6 +859,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
@jtu.run_on_devices('cpu', 'gpu', 'tpu')
|
||||
@jtu.thread_unsafe_test()
|
||||
def test_debug_print_jit(self, jit):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest(
|
||||
|
Loading…
x
Reference in New Issue
Block a user