mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
improve tests
This commit is contained in:
parent
a46d0028cc
commit
d5c940820f
@ -3389,6 +3389,8 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_initial_style_vmap_2(self):
|
||||
# This is like test_initial_style_vmap except the primal function closes
|
||||
# over an array constant.
|
||||
y = jnp.array([1., 2., 3.])
|
||||
|
||||
@api.custom_jvp
|
||||
@ -3409,6 +3411,22 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
expected = 2. * jnp.ones(3)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3))
|
||||
expected = 2. * jnp.ones(3)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3))
|
||||
expected = 2. * jnp.ones(3)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
|
||||
expected = 2. * jnp.ones(3)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
|
||||
expected = 2. * jnp.ones(3)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
@ -3628,6 +3646,33 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
expected = 2. * jnp.cos(jnp.arange(3.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_initial_style_vmap_2(self):
|
||||
# This is like test_initial_style_vmap except the primal function closes
|
||||
# over an array constant.
|
||||
y = jnp.array([1., 2., 3.])
|
||||
|
||||
@api.custom_vjp
|
||||
def f(x):
|
||||
assert jnp.ndim(x) == 0
|
||||
return 3 * x * jnp.sum(y)
|
||||
def f_fwd(x):
|
||||
return f(x), jnp.cos(x)
|
||||
def f_rev(cos_x, g):
|
||||
return (2 * cos_x * g,)
|
||||
f.defvjp(f_fwd, f_rev)
|
||||
|
||||
def foo(x):
|
||||
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
||||
return out
|
||||
|
||||
ans = api.vmap(foo)(jnp.arange(3.))
|
||||
expected = 3. * jnp.arange(3.) * 6
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
|
||||
expected = 2. * jnp.cos(jnp.arange(3.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_nondiff_arg(self):
|
||||
@partial(api.custom_vjp, nondiff_argnums=(0,))
|
||||
def app(f, x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user