mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add test for #2657
This commit is contained in:
parent
61abdc1e10
commit
7750a16cf8
@ -2446,6 +2446,34 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
api.eval_shape(expit, np.ones((2, 3)))
|
||||
api.eval_shape(api.grad(lambda x: expit(x).sum()), np.ones((2, 3)))
|
||||
|
||||
def test_jaxpr_zeros(self):
|
||||
# from https://github.com/google/jax/issues/2657
|
||||
@api.custom_jvp
|
||||
def f(A, b):
|
||||
return A @ b
|
||||
|
||||
def f_jvp(primals, tangents):
|
||||
A, b = primals
|
||||
dA, db = tangents
|
||||
z = f(A, b)
|
||||
dz = A @ db + dA @ b
|
||||
return z, dz
|
||||
|
||||
f.defjvp(f_jvp)
|
||||
|
||||
def experiment(theta):
|
||||
def step(q, _):
|
||||
z = f(np.eye(3), np.ones(3) * theta)
|
||||
q += z[0]
|
||||
return q, q
|
||||
|
||||
q = 0.
|
||||
q, _ = lax.scan(step, q, None, 4)
|
||||
return q
|
||||
|
||||
grad(experiment)(1.) # doesn't crash
|
||||
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user