Relaxed check to allow both tuples and lists

This commit is contained in:
George Necula 2019-11-27 14:24:41 +01:00
parent c1d8d3f74d
commit e0706ff864
2 changed files with 15 additions and 10 deletions

View File

@ -1064,13 +1064,11 @@ def jvp(fun, primals, tangents):
or standard Python containers of arrays or scalars. It should return an
array, scalar, or standard Python container of arrays or scalars.
primals: The primal values at which the Jacobian of `fun` should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof. The length of the tuple is equal to the number of
positional parameters of `fun`.
evaluated. Should be either a tuple or a list of arguments,
and its length should equal to the number of positional parameters of `fun`.
tangents: The tangent vector for which the Jacobian-vector product should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof, with the same tree structure and array shapes as
`primals`.
evaluated. Should be either a tuple or a list of tangents, with the same
tree structure and array shapes as `primals`.
Returns:
A `(primals_out, tangents_out)` pair, where `primals_out` is
@ -1089,8 +1087,9 @@ def jvp(fun, primals, tangents):
if not isinstance(fun, lu.WrappedFun):
fun = lu.wrap_init(fun)
if not isinstance(primals, tuple) or not isinstance(tangents, tuple):
msg = ("primal and tangent arguments to jax.jvp must be tuples; "
if (not isinstance(primals, (tuple, list)) or
not isinstance(tangents, (tuple, list))):
msg = ("primal and tangent arguments to jax.jvp must be tuples or lists; "
"found {} and {}.")
raise TypeError(msg.format(type(primals).__name__, type(tangents).__name__))

View File

@ -495,6 +495,12 @@ class APITest(jtu.JaxTestCase):
("primal and tangent arguments to jax.jvp must have the same tree "
"structure"),
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), ()))
# If primals and tangents must both be tuples or both lists
self.assertRaisesRegex(
TypeError,
("primal and tangent arguments to jax.jvp must have the same tree "
"structure"),
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), [onp.float32(2)]))
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must have equal types",
@ -505,11 +511,11 @@ class APITest(jtu.JaxTestCase):
def f(x, y): return x + y
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples; found float and tuple.",
"primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.",
lambda: partial(api.jvp(f, 0., (1.,))))
self.assertRaisesRegex(
TypeError,
"primal and tangent arguments to jax.jvp must be tuples; found tuple and ndarray.",
"primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.",
lambda: partial(api.jvp(f, (0.,), onp.array([1., 2.]))))
def test_vjp_mismatched_arguments(self):