mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added check for shapes of arguments in jvp, resolves issue #5226
This commit is contained in:
parent
cfddde712b
commit
a47abe06ed
@ -1720,6 +1720,12 @@ def _jvp(fun: lu.WrappedFun, primals, tangents):
|
||||
f"Got primal dtype {_dtype(p)} and so expected tangent dtype "
|
||||
f"{core.primal_dtype_to_tangent_dtype(_dtype(p))}, but got "
|
||||
f"tangent dtype {_dtype(t)} instead.")
|
||||
try:
|
||||
if p.shape != t.shape:
|
||||
raise ValueError("jvp called with inconsistent primal and tangent shapes;"
|
||||
f"Got primal shape {p.shape} and tangent shape as {t.shape}")
|
||||
except AttributeError:
|
||||
pass
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def)
|
||||
out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
|
||||
return (tree_unflatten(out_tree(), out_primals),
|
||||
|
@ -947,6 +947,12 @@ class APITest(jtu.JaxTestCase):
|
||||
TypeError,
|
||||
"primal and tangent arguments to jax.jvp do not match.",
|
||||
lambda: api.jvp(lambda x: -x, (np.float16(2),), (np.float32(4),)))
|
||||
# If primals and tangents are not of the same shape then raise error
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"jvp called with inconsistent primal and tangent shapes",
|
||||
lambda: api.jvp(lambda x: x+1, (np.random.randn(10,),), (np.random.randn(20,),))
|
||||
)
|
||||
|
||||
def test_jvp_non_tuple_arguments(self):
|
||||
def f(x, y): return x + y
|
||||
|
Loading…
x
Reference in New Issue
Block a user