mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
regex check for each test case
This commit is contained in:
parent
bd9ac93b6b
commit
9260f2c7d7
@ -1721,7 +1721,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents):
|
||||
f"{core.primal_dtype_to_tangent_dtype(_dtype(p))}, but got "
|
||||
f"tangent dtype {_dtype(t)} instead.")
|
||||
if np.shape(p) != np.shape(t):
|
||||
raise ValueError("jvp called with inconsistent primal and tangent shapes;"
|
||||
raise ValueError("jvp called with different primal and tangent shapes;"
|
||||
f"Got primal shape {np.shape(p)} and tangent shape as {np.shape(t)}")
|
||||
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def)
|
||||
|
@ -949,10 +949,15 @@ class APITest(jtu.JaxTestCase):
|
||||
lambda: api.jvp(lambda x: -x, (np.float16(2),), (np.float32(4),)))
|
||||
# If primals and tangents are not of the same shape then raise error
|
||||
fun = lambda x: x+1
|
||||
with self.assertRaisesRegex(ValueError, "jvp called with inconsistent primal and tangent shapes"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "jvp called with different primal and tangent shapes"):
|
||||
api.jvp(fun, (jnp.array([1.,2.,3.]),), (jnp.array([1.,2.,3.,4.]),))
|
||||
api.jvp(fun, (jnp.float(10.),), (jnp.array([1.,2.,3.]),))
|
||||
api.jvp(fun, (jnp.array([1.,2.,3.]),), (jnp.float(20.),))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "jvp called with different primal and tangent shapes"):
|
||||
api.jvp(fun, (jnp.float32(10.),), (jnp.array([1.,2.,3.]),))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "jvp called with different primal and tangent shapes"):
|
||||
api.jvp(fun, (jnp.array([1.,2.,3.]),), (jnp.float32(20.),))
|
||||
|
||||
def test_jvp_non_tuple_arguments(self):
|
||||
def f(x, y): return x + y
|
||||
|
Loading…
x
Reference in New Issue
Block a user