mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Instantiate zero outputs of linear_transpose
This commit is contained in:
parent
848fed8b87
commit
0a3ba6f2ce
@ -1963,7 +1963,9 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
|
||||
raise TypeError("cotangent type does not match function output, "
|
||||
f"expected {out_avals} but got {out_cotangents}")
|
||||
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
|
||||
in_cotangents = ad.backward_pass(jaxpr, consts, dummies, out_cotangents)
|
||||
in_cotangents = map(
|
||||
ad.instantiate_zeros,
|
||||
ad.backward_pass(jaxpr, consts, dummies, out_cotangents))
|
||||
return tree_unflatten(in_tree, in_cotangents)
|
||||
|
||||
return transposed_fun
|
||||
|
@ -1048,6 +1048,13 @@ class APITest(jtu.JaxTestCase):
|
||||
expected = -5 + 10j
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_linear_transpose_zeros(self):
|
||||
f = lambda x: x[0]
|
||||
transpose = api.linear_transpose(f, [1., 2.])
|
||||
actual, = transpose(3.)
|
||||
expected = [3., 0.]
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_complex_grad_raises_error(self):
|
||||
self.assertRaises(TypeError, lambda: grad(lambda x: jnp.sin(x))(1 + 2j))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user