improve error messages for grad(..., has_aux=True)

fixes #5776
This commit is contained in:
Matthew Johnson 2021-02-18 09:46:16 -08:00
parent 8ad5118228
commit 5a97eab2a0
2 changed files with 16 additions and 1 deletions

View File

@ -75,7 +75,11 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans, aux = yield py_args, {}
pair = yield py_args, {}
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
raise TypeError("expected function with aux output to return a two-element "
f"tuple, but got type {type(pair)} with value {repr(pair)}")
ans, aux = pair
ans_flat, ans_tree = tree_flatten(ans)
aux_flat, aux_tree = tree_flatten(aux)
yield (ans_flat, aux_flat), (ans_tree, aux_tree)

View File

@ -876,6 +876,16 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(g, grad(lambda x: x**3)(3.))
self.assertAllClose(aux, [9.], check_dtypes=False)
def test_grad_and_aux_error(self):
with self.assertRaisesRegex(TypeError, "two-element tuple"):
grad(lambda x: (1, 2, 3), has_aux=True)(1.)
with self.assertRaisesRegex(TypeError, "two-element tuple"):
grad(lambda x: x, has_aux=True)(1.)
with self.assertRaisesRegex(TypeError, "two-element tuple"):
grad(lambda x: (x,), has_aux=True)(1.)
def test_grad_and_aux_nested(self):
def f(x):
g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
@ -2319,6 +2329,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(f(x), f(a))
class RematTest(jtu.JaxTestCase):
def test_remat_basic(self):