mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Avoid assuming that jnp.sin will be traced in abstract mesh tests
The test does not clear the JAX caches, and jax.sin is a jitted closure that's shared between all test methods, so there's no guarantee that someone hasn't already traced sine at that same shape before. This only shows up rarely since it depends on the subset of tests assigned to the same test executor. PiperOrigin-RevId: 706706380
This commit is contained in:
parent
11e0fdf3e7
commit
3b9a8f7913
@ -4523,7 +4523,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
x = with_sharding_constraint(
|
||||
x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x')))
|
||||
return jnp.sin(x)
|
||||
return jax.lax.sin(x)
|
||||
|
||||
with (
|
||||
jtu.count_jit_tracing_cache_miss() as tracing_count,
|
||||
@ -4536,7 +4536,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# same num_devices but different devices.
|
||||
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
|
||||
f(b) # tracing and lowering cache *hit*
|
||||
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
|
||||
|
||||
self.assertEqual(tracing_count(), 1)
|
||||
self.assertEqual(lowering_count(), 1)
|
||||
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
|
||||
|
||||
|
@ -811,7 +811,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
x = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('i'),
|
||||
out_specs=P('i'))(x)
|
||||
return jnp.sin(x)
|
||||
return jax.lax.sin(x)
|
||||
|
||||
with (
|
||||
jtu.count_jit_tracing_cache_miss() as tracing_count,
|
||||
@ -825,7 +825,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
|
||||
f(b) # tracing and lowering cache *hit*
|
||||
|
||||
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
|
||||
self.assertEqual(tracing_count(), 1)
|
||||
self.assertEqual(lowering_count(), 1)
|
||||
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user