mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
8ad5118228
commit
5a97eab2a0
@ -75,7 +75,11 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_nokwargs2(in_tree, *args_flat):
|
||||
py_args = tree_unflatten(in_tree, args_flat)
|
||||
ans, aux = yield py_args, {}
|
||||
pair = yield py_args, {}
|
||||
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
|
||||
raise TypeError("expected function with aux output to return a two-element "
|
||||
f"tuple, but got type {type(pair)} with value {repr(pair)}")
|
||||
ans, aux = pair
|
||||
ans_flat, ans_tree = tree_flatten(ans)
|
||||
aux_flat, aux_tree = tree_flatten(aux)
|
||||
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
|
||||
|
@ -876,6 +876,16 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(g, grad(lambda x: x**3)(3.))
|
||||
self.assertAllClose(aux, [9.], check_dtypes=False)
|
||||
|
||||
def test_grad_and_aux_error(self):
|
||||
with self.assertRaisesRegex(TypeError, "two-element tuple"):
|
||||
grad(lambda x: (1, 2, 3), has_aux=True)(1.)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "two-element tuple"):
|
||||
grad(lambda x: x, has_aux=True)(1.)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "two-element tuple"):
|
||||
grad(lambda x: (x,), has_aux=True)(1.)
|
||||
|
||||
def test_grad_and_aux_nested(self):
|
||||
def f(x):
|
||||
g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
|
||||
@ -2319,6 +2329,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(f(x), f(a))
|
||||
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
def test_remat_basic(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user