jax.jacobian: propagate function signature to transformed function

This commit is contained in:
Jake VanderPlas 2022-10-04 10:21:54 -07:00
parent ae49d2e033
commit 0d9367972b
2 changed files with 29 additions and 0 deletions

View File

@ -1249,6 +1249,12 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
_check_callable(fun)
argnums = _ensure_index(argnums)
docstr = ("Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}.")
@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args,
@ -1331,6 +1337,12 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
"""
_check_callable(fun)
docstr = ("Jacobian of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"jacobian of the output with respect to the arguments at "
"positions {argnums}.")
@wraps(fun, docstr=docstr, argnums=argnums)
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args,

View File

@ -1233,6 +1233,23 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"):
jax.jit(f)(x)
@parameterized.named_parameters(
('grad', jax.grad),
('jacfwd', jax.jacfwd),
('jacref', jax.jacrev),
)
def test_grad_wrap(self, transform):
# Ensures that transforms wrap transformed functions with the correct signature.
@partial(jit, static_argnames=['flag'])
@transform
def my_function(x, flag):
return x if flag else jnp.zeros_like(x)
self.assertEqual(my_function(1.0, False), 0.0)
self.assertEqual(my_function(1.0, True), 1.0)
def test_grad_bad_input(self):
def f(x):
return x