Add documentation for jax.jvp and jax.vjp.

This commit is contained in:
Peter Hawkins 2019-02-19 22:08:14 -05:00
parent 34d3d81387
commit b714cb30cc
2 changed files with 49 additions and 1 deletions

View File

@ -17,6 +17,6 @@ Module contents
---------------
.. automodule:: jax
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, make_jaxpr
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, jvp, vjp, make_jaxpr
:undoc-members:
:show-inheritance:

View File

@ -410,6 +410,29 @@ def papply(fun, in_axes=0):
def jvp(fun, primals, tangents):
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
Args:
fun: Function to be differentiated. Its arguments should be arrays, scalars,
or standard Python containers of arrays or scalars. It should return an
array, scalar, or standard Python container of arrays or scalars.
primals: The primal values at which the Jacobian of `fun` should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof.
tangents: The tangent vector for which the Jacobian-vector product should be
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof.
Returns:
A `(primals_out, tangents_out)` pair, where `primals_out` is `fun(primals)`,
and `tangents_out` is the Jacobian-vector product of `function` evaluated at
`primals` with `tangents`.
For example:
>>> jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
(array(0.09983342, dtype=float32), array(0.19900084, dtype=float32))
"""
def trim_arg(primal, tangent):
primal_jtuple, tree_def = pytree_to_jaxtupletree(primal)
tangent_jtuple, tree_def_2 = pytree_to_jaxtupletree(tangent)
@ -443,6 +466,31 @@ def lift_linearized(jaxpr, consts, io_tree, out_pval, *py_args):
return apply_jaxtree_fun(fun, io_tree, *py_args)
def vjp(fun, *primals):
"""Compute a (reverse-mode) vector-Jacobian product of `fun`.
This is a more general form of `grad` that can be used for functions with
non-scalar outputs. For most common use cases, you most likely want `grad`
instead of `vjp`.
Args:
fun: Function to be differentiated. Its arguments should be arrays, scalars,
or standard Python containers of arrays or scalars. It should return an
array, scalar, or standard Python container of arrays or scalars.
primals: A sequence of of primal values at which the Jacobian of `fun`
should be evaluated. Should be a tuple of arrays, scalar, or standard
Python containers thereof.
Returns:
A `(primals_out, gradient)` pair, where `primals_out` is `fun(*primals)`.
`gradient` is a function from tangent values that computes vector-Jacobian
product of `fun` in an epsilon ball around `primals`.
>>> def f(x, y):
>>> return jax.numpy.sin(x), jax.numpy.cos(y)
>>> primals, g = jax.vjp(f, 0.5, 1.0)
>>> g((-0.7, 0.3))
(array(-0.61430776, dtype=float32), array(-0.2524413, dtype=float32))
"""
if not isinstance(fun, lu.WrappedFun):
fun = lu.wrap_init(fun)
primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))