add test for #2657

This commit is contained in:
Matthew Johnson 2020-04-10 11:45:33 -07:00
parent 61abdc1e10
commit 7750a16cf8

View File

@ -2446,6 +2446,34 @@ class CustomJVPTest(jtu.JaxTestCase):
api.eval_shape(expit, np.ones((2, 3)))
api.eval_shape(api.grad(lambda x: expit(x).sum()), np.ones((2, 3)))
def test_jaxpr_zeros(self):
# from https://github.com/google/jax/issues/2657
@api.custom_jvp
def f(A, b):
return A @ b
def f_jvp(primals, tangents):
A, b = primals
dA, db = tangents
z = f(A, b)
dz = A @ db + dA @ b
return z, dz
f.defjvp(f_jvp)
def experiment(theta):
def step(q, _):
z = f(np.eye(3), np.ones(3) * theta)
q += z[0]
return q, q
q = 0.
q, _ = lax.scan(step, q, None, 4)
return q
grad(experiment)(1.) # doesn't crash
class CustomVJPTest(jtu.JaxTestCase):
def test_basic(self):