add docstring for defjvp_all

This commit is contained in:
Matthew Johnson 2019-06-05 16:56:43 -07:00
parent 720dec4072
commit ab20f0292c
2 changed files with 42 additions and 2 deletions

View File

@ -19,6 +19,6 @@ Module contents
---------------
.. automodule:: jax
:members: jit, disable_jit, grad, value_and_grad, vmap, pmap, jacfwd, jacrev, hessian, jvp, linearize, vjp, make_jaxpr, eval_shape, custom_transforms
:members: jit, disable_jit, grad, value_and_grad, vmap, pmap, jacfwd, jacrev, hessian, jvp, linearize, vjp, make_jaxpr, eval_shape, custom_transforms, defjvp, defjvp2, defjvp_all, defvjp, defvjp2, defvjp_all, custom_gradient
:undoc-members:
:show-inheritance:

View File

@ -1076,7 +1076,47 @@ def _check_custom_transforms_type(name, fun):
raise TypeError(msg.format(name, type(fun)))
def defjvp_all(fun, custom_jvp):
"""Define a custom JVP rule for a custom_transforms function."""
"""Define a custom JVP rule for a custom_transforms function.
If ``fun`` represents a function with signature ``a -> b``, then
``custom_jvp`` represents a function with signature ``a -> T a -> (b, Tb)``,
where we use ``T x`` to represent a tangent type for the type ``x``.
Defining a custom JVP rule will also affect the dfeault VJP rule, which is
derived from the JVP rule automatically via transposition.
Args:
fun: a custom_transforms function.
custom_jvp: a Python callable specifying the JVP rule, taking two tuples as
arguments specifying the input primal values and tangent values,
respectively. The tuple elements can be arrays, scalars, or (nested)
standard Python containers (tuple/list/dict) thereof. Must be functionally
pure.
Returns:
None. A side-effect is that ``fun`` is associated with the JVP rule
specified by ``custom_jvp``.
For example:
>>> @jax.custom_transforms
... def f(x):
... return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
-10.933563
>>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
16.0
"""
_check_custom_transforms_type("defjvp_all", fun)
def custom_transforms_jvp(primals, tangents, **params):
jax_kwargs, jax_args = primals[0], primals[1:]