mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve error message for when backward function in custom_vjp does not return
a tuple. Prior to this we got an assertion that `py_cts_in is not iterable`.
This commit is contained in:
parent
3fd4c11925
commit
617d77e037
@ -545,12 +545,14 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
||||
# corresponding subtree of in_tree and with leaves of a non-pytree sentinel
|
||||
# object, to be replaced with Nones in the final returned result.
|
||||
zero = object() # non-pytree sentinel to replace Nones in py_cts_in
|
||||
py_cts_in_ = tuple(zero if ct is None else ct for ct in py_cts_in)
|
||||
dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
|
||||
cts_in_flat = []
|
||||
append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
|
||||
try:
|
||||
tree_multimap(append_cts, py_cts_in_, dummy)
|
||||
if not isinstance(py_cts_in, tuple):
|
||||
raise ValueError
|
||||
tree_multimap(append_cts,
|
||||
tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
|
||||
except ValueError:
|
||||
_, in_tree2 = tree_flatten(py_cts_in)
|
||||
msg = ("Custom VJP rule must produce an output with the same container "
|
||||
|
@ -4130,6 +4130,21 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
),
|
||||
lambda: api.grad(f)(2.))
|
||||
|
||||
def test_vjp_bwd_returns_non_tuple_error(self):
|
||||
@api.custom_vjp
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
def foo_fwd(x):
|
||||
return x, None
|
||||
|
||||
def foo_bwd(_, g):
|
||||
return 2. * g # Should be a tuple
|
||||
|
||||
f.defvjp(foo_fwd, foo_bwd)
|
||||
with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"):
|
||||
api.grad(f)(3.)
|
||||
|
||||
def test_issue2511(self):
|
||||
arr = jnp.ones((5, 2, 2))
|
||||
foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user