mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Merge pull request #9546 from pnkraemer:jet-docs
PiperOrigin-RevId: 432073995
This commit is contained in:
commit
d369501417
9
docs/jax.experimental.jet.rst
Normal file
9
docs/jax.experimental.jet.rst
Normal file
@ -0,0 +1,9 @@
|
||||
jax.experimental.jet module
|
||||
===========================
|
||||
|
||||
.. automodule:: jax.experimental.jet
|
||||
|
||||
API
|
||||
---
|
||||
|
||||
.. autofunction:: jet
|
@ -20,6 +20,7 @@ Experimental Modules
|
||||
jax.experimental.maps
|
||||
jax.experimental.pjit
|
||||
jax.experimental.sparse
|
||||
jax.experimental.jet
|
||||
|
||||
Experimental APIs
|
||||
-----------------
|
||||
|
@ -12,6 +12,46 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
r"""Jet is an experimental module for higher-order automatic differentiation
|
||||
that does not rely on repeated first-order automatic differentiation.
|
||||
|
||||
How? Through the propagation of truncated Taylor polynomials.
|
||||
Consider a function :math:`f = g \circ h`, some point :math:`x`
|
||||
and some offset :math:`v`.
|
||||
First-order automatic differentiation (such as :func:`jax.jvp`)
|
||||
computes the pair :math:`(f(x), \partial f(x)[v])` from the the pair
|
||||
:math:`(h(x), \partial h(x)[v])`.
|
||||
|
||||
:func:`jet` implements the higher-order analogue:
|
||||
Given the tuple
|
||||
|
||||
.. math::
|
||||
(h_0, ... h_K) :=
|
||||
(h(x), \partial h(x)[v], \partial^2 h(x)[v, v], ..., \partial^K h(x)[v,...,v]),
|
||||
|
||||
which represents a :math:`K`-th order Taylor approximation
|
||||
of :math:`h` at :math:`x`, :func:`jet` returns a :math:`K`-th order
|
||||
Taylor approximation of :math:`f` at :math:`x`,
|
||||
|
||||
.. math::
|
||||
(f_0, ..., f_K) :=
|
||||
(f(x), \partial f(x)[v], \partial^2 f(x)[v, v], ..., \partial^K f(x)[v,...,v]).
|
||||
|
||||
More specifically, :func:`jet` computes
|
||||
|
||||
.. math::
|
||||
f_0, (f_1, . . . , f_K) = \texttt{jet} (f, h_0, (h_1, . . . , h_K))
|
||||
|
||||
and can thus be used for high-order
|
||||
automatic differentiation of :math:`f`.
|
||||
Details are explained in
|
||||
`these notes <https://github.com/google/jax/files/6717197/jet.pdf>`__.
|
||||
|
||||
Note:
|
||||
Help improve :func:`jet` by contributing
|
||||
`outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from functools import partial
|
||||
@ -32,6 +72,54 @@ from jax.custom_derivatives import custom_jvp_call_jaxpr_p
|
||||
from jax import lax
|
||||
|
||||
def jet(fun, primals, series):
|
||||
r"""Taylor-mode higher-order automatic differentiation.
|
||||
|
||||
Args:
|
||||
fun: Function to be differentiated. Its arguments should be arrays, scalars,
|
||||
or standard Python containers of arrays or scalars. It should return an
|
||||
array, scalar, or standard Python container of arrays or scalars.
|
||||
primals: The primal values at which the Taylor approximation of ``fun`` should be
|
||||
evaluated. Should be either a tuple or a list of arguments,
|
||||
and its length should be equal to the number of positional parameters of
|
||||
``fun``.
|
||||
series: Higher order Taylor-series-coefficients.
|
||||
Together, `primals` and `series` make up a truncated Taylor polynomial.
|
||||
Should be either a tuple or a list of tuples or lists,
|
||||
and its length dictates the degree of the truncated Taylor polynomial.
|
||||
|
||||
Returns:
|
||||
A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
|
||||
and together, ``primals_out`` and ``series_out`` are a
|
||||
truncated Taylor polynomial of :math:`f(h(\cdot))`.
|
||||
The ``primals_out`` value has the same Python tree structure as ``primals``,
|
||||
and the ``series_out`` value the same Python tree structure as ``series``.
|
||||
|
||||
For example:
|
||||
|
||||
>>> import jax
|
||||
>>> import jax.numpy as np
|
||||
|
||||
Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
|
||||
and the first few Taylor coefficients
|
||||
:math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
|
||||
Let :math:`f(y) = \sin(y)`.
|
||||
|
||||
>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
|
||||
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
|
||||
|
||||
:func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
|
||||
according to Faà di Bruno's formula:
|
||||
|
||||
>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
|
||||
>>> print(f0, f(h0))
|
||||
0.12467473 0.12467473
|
||||
|
||||
>>> print(f1, df(h0) * h1)
|
||||
0.7441479 0.74414825
|
||||
|
||||
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
|
||||
2.9064622 2.9064634
|
||||
"""
|
||||
try:
|
||||
order, = set(map(len, series))
|
||||
except ValueError:
|
||||
|
Loading…
x
Reference in New Issue
Block a user