Merge pull request #9546 from pnkraemer:jet-docs

PiperOrigin-RevId: 432073995
This commit is contained in:
jax authors 2022-03-02 18:04:40 -08:00
commit d369501417
3 changed files with 98 additions and 0 deletions

View File

@ -0,0 +1,9 @@
jax.experimental.jet module
===========================
.. automodule:: jax.experimental.jet
API
---
.. autofunction:: jet

View File

@ -20,6 +20,7 @@ Experimental Modules
jax.experimental.maps
jax.experimental.pjit
jax.experimental.sparse
jax.experimental.jet
Experimental APIs
-----------------

View File

@ -12,6 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Jet is an experimental module for higher-order automatic differentiation
that does not rely on repeated first-order automatic differentiation.
How? Through the propagation of truncated Taylor polynomials.
Consider a function :math:`f = g \circ h`, some point :math:`x`
and some offset :math:`v`.
First-order automatic differentiation (such as :func:`jax.jvp`)
computes the pair :math:`(f(x), \partial f(x)[v])` from the the pair
:math:`(h(x), \partial h(x)[v])`.
:func:`jet` implements the higher-order analogue:
Given the tuple
.. math::
(h_0, ... h_K) :=
(h(x), \partial h(x)[v], \partial^2 h(x)[v, v], ..., \partial^K h(x)[v,...,v]),
which represents a :math:`K`-th order Taylor approximation
of :math:`h` at :math:`x`, :func:`jet` returns a :math:`K`-th order
Taylor approximation of :math:`f` at :math:`x`,
.. math::
(f_0, ..., f_K) :=
(f(x), \partial f(x)[v], \partial^2 f(x)[v, v], ..., \partial^K f(x)[v,...,v]).
More specifically, :func:`jet` computes
.. math::
f_0, (f_1, . . . , f_K) = \texttt{jet} (f, h_0, (h_1, . . . , h_K))
and can thus be used for high-order
automatic differentiation of :math:`f`.
Details are explained in
`these notes <https://github.com/google/jax/files/6717197/jet.pdf>`__.
Note:
Help improve :func:`jet` by contributing
`outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
"""
from typing import Callable
from functools import partial
@ -32,6 +72,54 @@ from jax.custom_derivatives import custom_jvp_call_jaxpr_p
from jax import lax
def jet(fun, primals, series):
r"""Taylor-mode higher-order automatic differentiation.
Args:
fun: Function to be differentiated. Its arguments should be arrays, scalars,
or standard Python containers of arrays or scalars. It should return an
array, scalar, or standard Python container of arrays or scalars.
primals: The primal values at which the Taylor approximation of ``fun`` should be
evaluated. Should be either a tuple or a list of arguments,
and its length should be equal to the number of positional parameters of
``fun``.
series: Higher order Taylor-series-coefficients.
Together, `primals` and `series` make up a truncated Taylor polynomial.
Should be either a tuple or a list of tuples or lists,
and its length dictates the degree of the truncated Taylor polynomial.
Returns:
A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
and together, ``primals_out`` and ``series_out`` are a
truncated Taylor polynomial of :math:`f(h(\cdot))`.
The ``primals_out`` value has the same Python tree structure as ``primals``,
and the ``series_out`` value the same Python tree structure as ``series``.
For example:
>>> import jax
>>> import jax.numpy as np
Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
and the first few Taylor coefficients
:math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
Let :math:`f(y) = \sin(y)`.
>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
:func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
according to Faà di Bruno's formula:
>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
>>> print(f0, f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.7441479 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
2.9064622 2.9064634
"""
try:
order, = set(map(len, series))
except ValueError: