mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add documentation for jax.jvp and jax.vjp.
This commit is contained in:
parent
34d3d81387
commit
b714cb30cc
@ -17,6 +17,6 @@ Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, make_jaxpr
|
||||
:members: jit, grad, value_and_grad, vmap, jacfwd, jacrev, hessian, jvp, vjp, make_jaxpr
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
48
jax/api.py
48
jax/api.py
@ -410,6 +410,29 @@ def papply(fun, in_axes=0):
|
||||
|
||||
|
||||
def jvp(fun, primals, tangents):
|
||||
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
|
||||
|
||||
Args:
|
||||
fun: Function to be differentiated. Its arguments should be arrays, scalars,
|
||||
or standard Python containers of arrays or scalars. It should return an
|
||||
array, scalar, or standard Python container of arrays or scalars.
|
||||
primals: The primal values at which the Jacobian of `fun` should be
|
||||
evaluated. Should be a tuple of arrays, scalar, or standard Python
|
||||
container thereof.
|
||||
tangents: The tangent vector for which the Jacobian-vector product should be
|
||||
evaluated. Should be a tuple of arrays, scalar, or standard Python
|
||||
container thereof.
|
||||
|
||||
Returns:
|
||||
A `(primals_out, tangents_out)` pair, where `primals_out` is `fun(primals)`,
|
||||
and `tangents_out` is the Jacobian-vector product of `function` evaluated at
|
||||
`primals` with `tangents`.
|
||||
|
||||
For example:
|
||||
|
||||
>>> jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
|
||||
(array(0.09983342, dtype=float32), array(0.19900084, dtype=float32))
|
||||
"""
|
||||
def trim_arg(primal, tangent):
|
||||
primal_jtuple, tree_def = pytree_to_jaxtupletree(primal)
|
||||
tangent_jtuple, tree_def_2 = pytree_to_jaxtupletree(tangent)
|
||||
@ -443,6 +466,31 @@ def lift_linearized(jaxpr, consts, io_tree, out_pval, *py_args):
|
||||
return apply_jaxtree_fun(fun, io_tree, *py_args)
|
||||
|
||||
def vjp(fun, *primals):
|
||||
"""Compute a (reverse-mode) vector-Jacobian product of `fun`.
|
||||
|
||||
This is a more general form of `grad` that can be used for functions with
|
||||
non-scalar outputs. For most common use cases, you most likely want `grad`
|
||||
instead of `vjp`.
|
||||
|
||||
Args:
|
||||
fun: Function to be differentiated. Its arguments should be arrays, scalars,
|
||||
or standard Python containers of arrays or scalars. It should return an
|
||||
array, scalar, or standard Python container of arrays or scalars.
|
||||
primals: A sequence of of primal values at which the Jacobian of `fun`
|
||||
should be evaluated. Should be a tuple of arrays, scalar, or standard
|
||||
Python containers thereof.
|
||||
|
||||
Returns:
|
||||
A `(primals_out, gradient)` pair, where `primals_out` is `fun(*primals)`.
|
||||
`gradient` is a function from tangent values that computes vector-Jacobian
|
||||
product of `fun` in an epsilon ball around `primals`.
|
||||
|
||||
>>> def f(x, y):
|
||||
>>> return jax.numpy.sin(x), jax.numpy.cos(y)
|
||||
>>> primals, g = jax.vjp(f, 0.5, 1.0)
|
||||
>>> g((-0.7, 0.3))
|
||||
(array(-0.61430776, dtype=float32), array(-0.2524413, dtype=float32))
|
||||
"""
|
||||
if not isinstance(fun, lu.WrappedFun):
|
||||
fun = lu.wrap_init(fun)
|
||||
primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))
|
||||
|
Loading…
x
Reference in New Issue
Block a user