Avoid assuming that jnp.sin will be traced in abstract mesh tests

The test does not clear the JAX caches, and jax.sin is a jitted closure
that's shared between all test methods, so there's no guarantee that someone
hasn't already traced sine at that same shape before. This only shows up rarely
since it depends on the subset of tests assigned to the same test executor.

PiperOrigin-RevId: 706706380
This commit is contained in:
Adam Paszke 2024-12-16 07:44:32 -08:00 committed by jax authors
parent 11e0fdf3e7
commit 3b9a8f7913
2 changed files with 5 additions and 4 deletions

View File

@ -4523,7 +4523,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def f(x):
x = with_sharding_constraint(
x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x')))
return jnp.sin(x)
return jax.lax.sin(x)
with (
jtu.count_jit_tracing_cache_miss() as tracing_count,
@ -4536,7 +4536,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
# same num_devices but different devices.
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
f(b) # tracing and lowering cache *hit*
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
self.assertEqual(tracing_count(), 1)
self.assertEqual(lowering_count(), 1)
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.

View File

@ -811,7 +811,7 @@ class ShardMapTest(jtu.JaxTestCase):
def f(x):
x = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('i'),
out_specs=P('i'))(x)
return jnp.sin(x)
return jax.lax.sin(x)
with (
jtu.count_jit_tracing_cache_miss() as tracing_count,
@ -825,7 +825,7 @@ class ShardMapTest(jtu.JaxTestCase):
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
f(b) # tracing and lowering cache *hit*
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
self.assertEqual(tracing_count(), 1)
self.assertEqual(lowering_count(), 1)
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.