mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax.jacobian: propagate function signature to transformed function
This commit is contained in:
parent
ae49d2e033
commit
0d9367972b
@ -1249,6 +1249,12 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
_check_callable(fun)
|
||||
argnums = _ensure_index(argnums)
|
||||
|
||||
docstr = ("Jacobian of {fun} with respect to positional argument(s) "
|
||||
"{argnums}. Takes the same arguments as {fun} but returns the "
|
||||
"jacobian of the output with respect to the arguments at "
|
||||
"positions {argnums}.")
|
||||
|
||||
@wraps(fun, docstr=docstr, argnums=argnums)
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args,
|
||||
@ -1331,6 +1337,12 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
"""
|
||||
_check_callable(fun)
|
||||
|
||||
docstr = ("Jacobian of {fun} with respect to positional argument(s) "
|
||||
"{argnums}. Takes the same arguments as {fun} but returns the "
|
||||
"jacobian of the output with respect to the arguments at "
|
||||
"positions {argnums}.")
|
||||
|
||||
@wraps(fun, docstr=docstr, argnums=argnums)
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args,
|
||||
|
@ -1233,6 +1233,23 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"):
|
||||
jax.jit(f)(x)
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('grad', jax.grad),
|
||||
('jacfwd', jax.jacfwd),
|
||||
('jacref', jax.jacrev),
|
||||
)
|
||||
def test_grad_wrap(self, transform):
|
||||
# Ensures that transforms wrap transformed functions with the correct signature.
|
||||
|
||||
@partial(jit, static_argnames=['flag'])
|
||||
@transform
|
||||
def my_function(x, flag):
|
||||
return x if flag else jnp.zeros_like(x)
|
||||
|
||||
self.assertEqual(my_function(1.0, False), 0.0)
|
||||
self.assertEqual(my_function(1.0, True), 1.0)
|
||||
|
||||
def test_grad_bad_input(self):
|
||||
def f(x):
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user