Instantiate zero outputs of linear_transpose

This commit is contained in:
Jamie Townsend 2021-03-26 10:50:24 +00:00
parent 848fed8b87
commit 0a3ba6f2ce
2 changed files with 10 additions and 1 deletions

View File

@ -1963,7 +1963,9 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
raise TypeError("cotangent type does not match function output, "
f"expected {out_avals} but got {out_cotangents}")
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
in_cotangents = ad.backward_pass(jaxpr, consts, dummies, out_cotangents)
in_cotangents = map(
ad.instantiate_zeros,
ad.backward_pass(jaxpr, consts, dummies, out_cotangents))
return tree_unflatten(in_tree, in_cotangents)
return transposed_fun

View File

@ -1048,6 +1048,13 @@ class APITest(jtu.JaxTestCase):
expected = -5 + 10j
self.assertEqual(actual, expected)
def test_linear_transpose_zeros(self):
f = lambda x: x[0]
transpose = api.linear_transpose(f, [1., 2.])
actual, = transpose(3.)
expected = [3., 0.]
self.assertEqual(actual, expected)
def test_complex_grad_raises_error(self):
self.assertRaises(TypeError, lambda: grad(lambda x: jnp.sin(x))(1 + 2j))