improve tests

This commit is contained in:
Matthew Johnson 2020-10-20 21:16:00 -07:00
parent a46d0028cc
commit d5c940820f

View File

@ -3389,6 +3389,8 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_initial_style_vmap_2(self):
# This is like test_initial_style_vmap except the primal function closes
# over an array constant.
y = jnp.array([1., 2., 3.])
@api.custom_jvp
@ -3409,6 +3411,22 @@ class CustomJVPTest(jtu.JaxTestCase):
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
class CustomVJPTest(jtu.JaxTestCase):
@ -3628,6 +3646,33 @@ class CustomVJPTest(jtu.JaxTestCase):
expected = 2. * jnp.cos(jnp.arange(3.))
self.assertAllClose(ans, expected, check_dtypes=False)
def test_initial_style_vmap_2(self):
# This is like test_initial_style_vmap except the primal function closes
# over an array constant.
y = jnp.array([1., 2., 3.])
@api.custom_vjp
def f(x):
assert jnp.ndim(x) == 0
return 3 * x * jnp.sum(y)
def f_fwd(x):
return f(x), jnp.cos(x)
def f_rev(cos_x, g):
return (2 * cos_x * g,)
f.defvjp(f_fwd, f_rev)
def foo(x):
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
return out
ans = api.vmap(foo)(jnp.arange(3.))
expected = 3. * jnp.arange(3.) * 6
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
expected = 2. * jnp.cos(jnp.arange(3.))
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg(self):
@partial(api.custom_vjp, nondiff_argnums=(0,))
def app(f, x):