Merge pull request #21646 from dfm:gh21643

PiperOrigin-RevId: 640559336
This commit is contained in:
jax authors 2024-06-05 09:57:35 -07:00
commit 7771cd25b1

View File

@ -2075,13 +2075,10 @@ class PythonPmapTest(jtu.JaxTestCase):
def f(x):
return jnp.sin(x)
# warm-up the cache
x = jnp.ones(axis_size)
f(x) # warm-up any dispatching compilations
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_, f_bwd = jax.vjp(f, x)
_ = f_bwd(x)
self.assertEqual(count[0], 2) # one for fwd, one for bwd
_, f_bwd = jax.vjp(f, x)
_ = f_bwd(x)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_, f_bwd2 = jax.vjp(f, x)