mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
broken remat test!
This commit is contained in:
parent
17d0dd701c
commit
ec8252fce2
@ -4344,6 +4344,15 @@ class RematTest(jtu.JaxTestCase):
|
||||
f_vjp(1.)[0].block_until_ready()
|
||||
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
|
||||
|
||||
def test_remat_of_scan(self):
|
||||
to_scan = lambda c, _: (jnp.sin(c), jnp.sin(c))
|
||||
f = lambda x: lax.scan(to_scan, x, None, length=3)
|
||||
jtu.check_grads(jax.remat(f), (3.,), order=2, modes=['rev'])
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(jax.remat(f), 4.)[1])(1.)
|
||||
self.assertIn(' sin ', str(jaxpr))
|
||||
self.assertIn(' cos ', str(jaxpr))
|
||||
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user