mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Updated uses of make_jaxpr in new code
This commit is contained in:
parent
2b0b04fcad
commit
3b97c5f792
@ -1509,11 +1509,11 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.eqns
|
||||
scan_eqn, = jaxpr.jaxpr.eqns
|
||||
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.eqns
|
||||
scan_eqn, = jaxpr.jaxpr.eqns
|
||||
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
def test_remat_no_redundant_flops(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user