mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
add new custom_jvp tests from #2500
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent
bcc5191c63
commit
67283a08ec
@ -2289,6 +2289,76 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
"respectively."),
|
||||
lambda: api.jvp(f, (np.float32(2.),), (np.float32(1.),)))
|
||||
|
||||
def test_multiple_rule_invocations(self):
|
||||
@jax.custom_jvp
|
||||
def expit(x):
|
||||
return 1 / (1 + lax.exp(-x))
|
||||
|
||||
@expit.defjvp
|
||||
def _expit_jvp(primals, tangents):
|
||||
(x,), (t,) = primals, tangents
|
||||
ans = expit(x)
|
||||
t_out = t * ans * (1 - ans)
|
||||
return ans, t_out
|
||||
|
||||
def scanned_fun(c, _):
|
||||
return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
|
||||
|
||||
def foo(x):
|
||||
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
|
||||
return c[-1]
|
||||
|
||||
# just make sure these don't crash
|
||||
foo(3.)
|
||||
grad(foo)(3.)
|
||||
grad(lambda x: jax.vmap(foo)(x).sum())(np.arange(3.))
|
||||
|
||||
def test_hard_stuff(self):
|
||||
arr = np.ones((5, 2, 2))
|
||||
api.jit(jax.vmap(np.linalg.det))(arr) # doesn't crash
|
||||
|
||||
def test_hard_stuff2(self):
|
||||
@jax.custom_jvp
|
||||
def f(x):
|
||||
return lax.tie_in(x, onp.zeros(x.shape, x.dtype))
|
||||
|
||||
@f.defjvp
|
||||
def f_jvp(primals, tangents):
|
||||
x, = primals
|
||||
t, = tangents
|
||||
return f(x), t
|
||||
|
||||
# don't crash
|
||||
jax.jit(jax.vmap(f))(np.arange(3.))
|
||||
jax.jit(jax.vmap(jax.grad(f)))(np.arange(3.))
|
||||
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(np.arange(3.))
|
||||
jax.grad(lambda x: jax.vmap(f)(x).sum())(np.arange(3.))
|
||||
jax.jvp(jax.vmap(f), (np.arange(3.),), (np.ones(3),))
|
||||
|
||||
def test_hard_stuff3(self):
|
||||
@jax.custom_jvp
|
||||
def relu(x):
|
||||
return np.maximum(x, 0)
|
||||
|
||||
@relu.defjvp
|
||||
def _relu_jvp(primals, tangents):
|
||||
x, = primals
|
||||
t, = tangents
|
||||
return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))
|
||||
|
||||
def scanned_fun(c, _):
|
||||
return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
|
||||
|
||||
def f(x):
|
||||
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
|
||||
return c[-1]
|
||||
|
||||
# don't crash
|
||||
jax.jit(jax.vmap(f))(np.arange(3.))
|
||||
jax.jit(jax.vmap(jax.grad(f)))(np.arange(3.))
|
||||
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(np.arange(3.))
|
||||
jax.grad(lambda x: jax.vmap(f)(x).sum())(np.arange(3.))
|
||||
jax.jvp(jax.jit(jax.vmap(f)), (np.arange(3.),), (np.ones(3),))
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user