mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix grad-and-aux handling of constant aux data
This commit is contained in:
parent
4542929b50
commit
caa2ed1a40
@ -66,8 +66,8 @@ def jvp_subtrace_aux(master, primals, tangents):
|
||||
for x in list(primals) + list(tangents):
|
||||
if isinstance(x, Tracer):
|
||||
assert x.trace.level < trace.level
|
||||
ans_and_aux = yield map(partial(JVPTracer, trace), primals, tangents)
|
||||
out_tracer, aux_tracer = trace.full_raise(ans_and_aux)
|
||||
ans, aux = yield map(partial(JVPTracer, trace), primals, tangents)
|
||||
out_tracer, aux_tracer = map(trace.full_raise, (ans, aux))
|
||||
out_primal, out_tangent = out_tracer.primal, out_tracer.tangent
|
||||
aux = aux_tracer.primal # ignore aux tangent
|
||||
yield (out_primal, out_tangent), aux
|
||||
|
@ -343,7 +343,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_and_aux_basic(self):
|
||||
g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
|
||||
self.assertEqual(type(aux), list)
|
||||
self.assertEqual(g, grad(lambda x: x**3)(3.))
|
||||
self.assertEqual(aux, [9.])
|
||||
|
||||
def test_grad_and_aux_nested(self):
|
||||
@ -367,6 +367,11 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
|
||||
self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))
|
||||
|
||||
def test_grad_and_aux_constant(self):
|
||||
g, aux = grad(lambda x: (x**3, [4.]), has_aux=True)(4.)
|
||||
self.assertEqual(g, grad(lambda x: x**3)(4.))
|
||||
self.assertEqual(aux, [4.])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user