mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Relaxed check to allow both tuples and lists
This commit is contained in:
parent
c1d8d3f74d
commit
e0706ff864
15
jax/api.py
15
jax/api.py
@ -1064,13 +1064,11 @@ def jvp(fun, primals, tangents):
|
||||
or standard Python containers of arrays or scalars. It should return an
|
||||
array, scalar, or standard Python container of arrays or scalars.
|
||||
primals: The primal values at which the Jacobian of `fun` should be
|
||||
evaluated. Should be a tuple of arrays, scalar, or standard Python
|
||||
container thereof. The length of the tuple is equal to the number of
|
||||
positional parameters of `fun`.
|
||||
evaluated. Should be either a tuple or a list of arguments,
|
||||
and its length should equal to the number of positional parameters of `fun`.
|
||||
tangents: The tangent vector for which the Jacobian-vector product should be
|
||||
evaluated. Should be a tuple of arrays, scalar, or standard Python
|
||||
container thereof, with the same tree structure and array shapes as
|
||||
`primals`.
|
||||
evaluated. Should be either a tuple or a list of tangents, with the same
|
||||
tree structure and array shapes as `primals`.
|
||||
|
||||
Returns:
|
||||
A `(primals_out, tangents_out)` pair, where `primals_out` is
|
||||
@ -1089,8 +1087,9 @@ def jvp(fun, primals, tangents):
|
||||
if not isinstance(fun, lu.WrappedFun):
|
||||
fun = lu.wrap_init(fun)
|
||||
|
||||
if not isinstance(primals, tuple) or not isinstance(tangents, tuple):
|
||||
msg = ("primal and tangent arguments to jax.jvp must be tuples; "
|
||||
if (not isinstance(primals, (tuple, list)) or
|
||||
not isinstance(tangents, (tuple, list))):
|
||||
msg = ("primal and tangent arguments to jax.jvp must be tuples or lists; "
|
||||
"found {} and {}.")
|
||||
raise TypeError(msg.format(type(primals).__name__, type(tangents).__name__))
|
||||
|
||||
|
@ -495,6 +495,12 @@ class APITest(jtu.JaxTestCase):
|
||||
("primal and tangent arguments to jax.jvp must have the same tree "
|
||||
"structure"),
|
||||
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), ()))
|
||||
# If primals and tangents must both be tuples or both lists
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
("primal and tangent arguments to jax.jvp must have the same tree "
|
||||
"structure"),
|
||||
lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), [onp.float32(2)]))
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"primal and tangent arguments to jax.jvp must have equal types",
|
||||
@ -505,11 +511,11 @@ class APITest(jtu.JaxTestCase):
|
||||
def f(x, y): return x + y
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"primal and tangent arguments to jax.jvp must be tuples; found float and tuple.",
|
||||
"primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.",
|
||||
lambda: partial(api.jvp(f, 0., (1.,))))
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"primal and tangent arguments to jax.jvp must be tuples; found tuple and ndarray.",
|
||||
"primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.",
|
||||
lambda: partial(api.jvp(f, (0.,), onp.array([1., 2.]))))
|
||||
|
||||
def test_vjp_mismatched_arguments(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user