regex check for each test case

This commit is contained in:
Gaurav Pathak 2021-01-19 09:12:11 -05:00
parent bd9ac93b6b
commit 9260f2c7d7
2 changed files with 9 additions and 4 deletions

View File

@ -1721,7 +1721,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents):
f"{core.primal_dtype_to_tangent_dtype(_dtype(p))}, but got "
f"tangent dtype {_dtype(t)} instead.")
if np.shape(p) != np.shape(t):
raise ValueError("jvp called with inconsistent primal and tangent shapes;"
raise ValueError("jvp called with different primal and tangent shapes;"
f"Got primal shape {np.shape(p)} and tangent shape as {np.shape(t)}")
flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def)

View File

@ -949,10 +949,15 @@ class APITest(jtu.JaxTestCase):
lambda: api.jvp(lambda x: -x, (np.float16(2),), (np.float32(4),)))
# If primals and tangents are not of the same shape then raise error
fun = lambda x: x+1
with self.assertRaisesRegex(ValueError, "jvp called with inconsistent primal and tangent shapes"):
with self.assertRaisesRegex(
ValueError, "jvp called with different primal and tangent shapes"):
api.jvp(fun, (jnp.array([1.,2.,3.]),), (jnp.array([1.,2.,3.,4.]),))
api.jvp(fun, (jnp.float(10.),), (jnp.array([1.,2.,3.]),))
api.jvp(fun, (jnp.array([1.,2.,3.]),), (jnp.float(20.),))
with self.assertRaisesRegex(
ValueError, "jvp called with different primal and tangent shapes"):
api.jvp(fun, (jnp.float32(10.),), (jnp.array([1.,2.,3.]),))
with self.assertRaisesRegex(
ValueError, "jvp called with different primal and tangent shapes"):
api.jvp(fun, (jnp.array([1.,2.,3.]),), (jnp.float32(20.),))
def test_jvp_non_tuple_arguments(self):
def f(x, y): return x + y