fix grad-and-aux handling of constant aux data

This commit is contained in:
Matthew Johnson 2019-03-07 14:48:05 -08:00
parent 4542929b50
commit caa2ed1a40
2 changed files with 8 additions and 3 deletions

View File

@ -66,8 +66,8 @@ def jvp_subtrace_aux(master, primals, tangents):
for x in list(primals) + list(tangents):
if isinstance(x, Tracer):
assert x.trace.level < trace.level
ans_and_aux = yield map(partial(JVPTracer, trace), primals, tangents)
out_tracer, aux_tracer = trace.full_raise(ans_and_aux)
ans, aux = yield map(partial(JVPTracer, trace), primals, tangents)
out_tracer, aux_tracer = map(trace.full_raise, (ans, aux))
out_primal, out_tangent = out_tracer.primal, out_tracer.tangent
aux = aux_tracer.primal # ignore aux tangent
yield (out_primal, out_tangent), aux

View File

@ -343,7 +343,7 @@ class APITest(jtu.JaxTestCase):
def test_grad_and_aux_basic(self):
g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
self.assertEqual(type(aux), list)
self.assertEqual(g, grad(lambda x: x**3)(3.))
self.assertEqual(aux, [9.])
def test_grad_and_aux_nested(self):
@ -367,6 +367,11 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))
def test_grad_and_aux_constant(self):
g, aux = grad(lambda x: (x**3, [4.]), has_aux=True)(4.)
self.assertEqual(g, grad(lambda x: x**3)(4.))
self.assertEqual(aux, [4.])
if __name__ == '__main__':
absltest.main()