broken remat test!

This commit is contained in:
Matthew Johnson 2022-04-29 10:56:03 -07:00
parent 17d0dd701c
commit ec8252fce2

View File

@ -4344,6 +4344,15 @@ class RematTest(jtu.JaxTestCase):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
def test_remat_of_scan(self):
to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
f = lambda x: lax.scan(to_scan, x, None, length=3)
jtu.check_grads(jax.remat(f), (3.,), order=2, modes=['rev'])
jaxpr = api.make_jaxpr(api.linearize(jax.remat(f), 4.)[1])(1.)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))
class JaxprTest(jtu.JaxTestCase):