Added test for python scalar

This commit is contained in:
Gaurav Pathak 2021-01-20 21:47:18 -05:00
parent 2d3f33d976
commit 186c97394d

View File

@ -958,6 +958,9 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError, "jvp called with different primal and tangent shapes"):
api.jvp(fun, (jnp.array([1.,2.,3.], dtype=jnp.float32),), (jnp.float32(20.),))
with self.assertRaisesRegex(
ValueError, "jvp called with different primal and tangent shapes"):
api.jvp(fun, (jnp.array([1.,2.,3.]),), (20.,))
def test_jvp_non_tuple_arguments(self):
def f(x, y): return x + y