mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #21646 from dfm:gh21643
PiperOrigin-RevId: 640559336
This commit is contained in:
commit
7771cd25b1
@ -2075,13 +2075,10 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
# warm-up the cache
|
||||
x = jnp.ones(axis_size)
|
||||
f(x) # warm-up any dispatching compilations
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
_, f_bwd = jax.vjp(f, x)
|
||||
_ = f_bwd(x)
|
||||
self.assertEqual(count[0], 2) # one for fwd, one for bwd
|
||||
_, f_bwd = jax.vjp(f, x)
|
||||
_ = f_bwd(x)
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
_, f_bwd2 = jax.vjp(f, x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user