Improve error message for when backward function in custom_vjp does not return

a tuple.

Prior to this we got an assertion that `py_cts_in is not iterable`.
This commit is contained in:
George Necula 2021-01-29 19:55:02 +01:00
parent 3fd4c11925
commit 617d77e037
2 changed files with 19 additions and 2 deletions

View File

@ -545,12 +545,14 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
# corresponding subtree of in_tree and with leaves of a non-pytree sentinel
# object, to be replaced with Nones in the final returned result.
zero = object() # non-pytree sentinel to replace Nones in py_cts_in
py_cts_in_ = tuple(zero if ct is None else ct for ct in py_cts_in)
dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
cts_in_flat = []
append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
try:
tree_multimap(append_cts, py_cts_in_, dummy)
if not isinstance(py_cts_in, tuple):
raise ValueError
tree_multimap(append_cts,
tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
except ValueError:
_, in_tree2 = tree_flatten(py_cts_in)
msg = ("Custom VJP rule must produce an output with the same container "

View File

@ -4130,6 +4130,21 @@ class CustomVJPTest(jtu.JaxTestCase):
),
lambda: api.grad(f)(2.))
def test_vjp_bwd_returns_non_tuple_error(self):
@api.custom_vjp
def f(x):
return x
def foo_fwd(x):
return x, None
def foo_bwd(_, g):
return 2. * g # Should be a tuple
f.defvjp(foo_fwd, foo_bwd)
with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"):
api.grad(f)(3.)
def test_issue2511(self):
arr = jnp.ones((5, 2, 2))
foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)