mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add error checking that arguments of jvp are tuples
This commit is contained in:
parent
ec79adccbb
commit
c1d8d3f74d
@ -1089,6 +1089,11 @@ def jvp(fun, primals, tangents):
|
||||
if not isinstance(fun, lu.WrappedFun):
|
||||
fun = lu.wrap_init(fun)
|
||||
|
||||
if not isinstance(primals, tuple) or not isinstance(tangents, tuple):
|
||||
msg = ("primal and tangent arguments to jax.jvp must be tuples; "
|
||||
"found {} and {}.")
|
||||
raise TypeError(msg.format(type(primals).__name__, type(tangents).__name__))
|
||||
|
||||
ps_flat, tree_def = tree_flatten(primals)
|
||||
ts_flat, tree_def_2 = tree_flatten(tangents)
|
||||
if tree_def != tree_def_2:
|
||||
|
@ -500,6 +500,18 @@ class APITest(jtu.JaxTestCase):
|
||||
"primal and tangent arguments to jax.jvp must have equal types",
|
||||
lambda: api.jvp(lambda x: -x, (onp.float16(2),), (onp.float32(4),)))
|
||||
|
||||
|
||||
def test_jvp_non_tuple_arguments(self):
|
||||
def f(x, y): return x + y
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"primal and tangent arguments to jax.jvp must be tuples; found float and tuple.",
|
||||
lambda: partial(api.jvp(f, 0., (1.,))))
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"primal and tangent arguments to jax.jvp must be tuples; found tuple and ndarray.",
|
||||
lambda: partial(api.jvp(f, (0.,), onp.array([1., 2.]))))
|
||||
|
||||
def test_vjp_mismatched_arguments(self):
|
||||
_, pullback = api.vjp(lambda x, y: x * y, onp.float32(3), onp.float32(4))
|
||||
self.assertRaisesRegex(
|
||||
|
@ -248,7 +248,7 @@ class GeneratedFunTest(jtu.JaxTestCase):
|
||||
tangents = [tangents[i] for i in dyn_argnums]
|
||||
fun, vals = partial_argnums(fun, vals, dyn_argnums)
|
||||
ans1, deriv1 = jvp_fd(fun, vals, tangents)
|
||||
ans2, deriv2 = jvp(fun, vals, tangents)
|
||||
ans2, deriv2 = jvp(fun, tuple(vals), tuple(tangents))
|
||||
check_all_close(ans1, ans2)
|
||||
check_all_close(deriv1, deriv2)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user