Updated uses of make_jaxpr in new code

This commit is contained in:
George Necula 2019-11-28 09:00:55 +01:00
parent 2b0b04fcad
commit 3b97c5f792

View File

@ -1509,11 +1509,11 @@ class JaxprTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
scan_eqn, = jaxpr.jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
scan_eqn, = jaxpr.jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
def test_remat_no_redundant_flops(self):