add new custom_jvp tests from #2500

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
Matthew Johnson 2020-03-28 13:52:40 -07:00
parent bcc5191c63
commit 67283a08ec

View File

@ -2289,6 +2289,76 @@ class CustomJVPTest(jtu.JaxTestCase):
"respectively."),
lambda: api.jvp(f, (np.float32(2.),), (np.float32(1.),)))
def test_multiple_rule_invocations(self):
@jax.custom_jvp
def expit(x):
return 1 / (1 + lax.exp(-x))
@expit.defjvp
def _expit_jvp(primals, tangents):
(x,), (t,) = primals, tangents
ans = expit(x)
t_out = t * ans * (1 - ans)
return ans, t_out
def scanned_fun(c, _):
return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
def foo(x):
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
return c[-1]
# just make sure these don't crash
foo(3.)
grad(foo)(3.)
grad(lambda x: jax.vmap(foo)(x).sum())(np.arange(3.))
def test_hard_stuff(self):
arr = np.ones((5, 2, 2))
api.jit(jax.vmap(np.linalg.det))(arr) # doesn't crash
def test_hard_stuff2(self):
@jax.custom_jvp
def f(x):
return lax.tie_in(x, onp.zeros(x.shape, x.dtype))
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), t
# don't crash
jax.jit(jax.vmap(f))(np.arange(3.))
jax.jit(jax.vmap(jax.grad(f)))(np.arange(3.))
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(np.arange(3.))
jax.grad(lambda x: jax.vmap(f)(x).sum())(np.arange(3.))
jax.jvp(jax.vmap(f), (np.arange(3.),), (np.ones(3),))
def test_hard_stuff3(self):
@jax.custom_jvp
def relu(x):
return np.maximum(x, 0)
@relu.defjvp
def _relu_jvp(primals, tangents):
x, = primals
t, = tangents
return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))
def scanned_fun(c, _):
return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
def f(x):
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
return c[-1]
# don't crash
jax.jit(jax.vmap(f))(np.arange(3.))
jax.jit(jax.vmap(jax.grad(f)))(np.arange(3.))
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(np.arange(3.))
jax.grad(lambda x: jax.vmap(f)(x).sum())(np.arange(3.))
jax.jvp(jax.jit(jax.vmap(f)), (np.arange(3.),), (np.ones(3),))
class CustomVJPTest(jtu.JaxTestCase):