DOC: add initial JAX glossary

Jake VanderPlas 2021-03-01 17:56:43 -08:00
parent 3960e63172
commit 8f5038d4b9
4 changed files with 83 additions and 0 deletions

docs/glossary.rst (new file, 74 lines)

@@ -0,0 +1,74 @@
JAX Glossary of Terms
=====================
.. glossary::
CPU
Short for *Central Processing Unit*, CPUs are the standard computational architecture
available in most computers. JAX can run computations on CPUs, but can often achieve
much better performance on :term:`GPU` and :term:`TPU`.
Device
The generic name used to refer to the :term:`CPU`, :term:`GPU`, or :term:`TPU` used
by JAX to perform computations.
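For instance, :func:`jax.devices` lists the devices visible to the current process; a minimal sketch::

    import jax

    print(jax.devices())              # e.g. [CpuDevice(id=0)] on a CPU-only machine
    print(jax.devices()[0].platform)  # 'cpu', 'gpu', or 'tpu'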
DeviceArray
JAX's analog of the :class:`numpy.ndarray`. See :class:`jax.interpreters.xla.DeviceArray`.
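As an illustrative sketch, arrays created via :mod:`jax.numpy` are DeviceArrays rather than NumPy arrays (the exact class name may vary by JAX version)::

    import jax.numpy as jnp
    import numpy as np

    x = jnp.arange(5)
    print(type(x))        # a DeviceArray type, not numpy.ndarray
    print(np.asarray(x))  # can be converted back to a standard NumPy array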
forward-mode autodiff
See :term:`JVP`.
functional programming
A programming paradigm in which programs are defined by applying and composing
:term:`pure functions<pure function>`. JAX is designed for use with functional programs.
GPU
Short for *Graphics Processing Unit*, GPUs were originally specialized for operations
related to rendering images on screen, but are now much more general-purpose. JAX is
able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`).
JIT
Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of
array operations to :term:`XLA`, most often accomplished using :func:`jax.jit`.
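For example, a minimal sketch (the function ``f`` is defined here only for demonstration)::

    import jax
    import jax.numpy as jnp

    @jax.jit                    # compile f with XLA the first time it is called
    def f(x):
      return x ** 2 + 3 * x

    print(f(jnp.arange(3.0)))   # [ 0.  4. 10.]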
JVP
Short for *Jacobian Vector Product*, also sometimes known as *forward-mode* automatic
differentiation. For more details, see :ref:`jacobian-vector-product`. In JAX, JVP is
a :term:`transformation` that is implemented via :func:`jax.jvp`. See also :term:`VJP`.
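:func:`jax.jvp` evaluates a function and its Jacobian-vector product in a single forward pass; a small sketch (``f``, ``x``, and ``v`` are illustrative)::

    import jax
    import jax.numpy as jnp

    f = lambda x: jnp.sin(x) * x
    x = jnp.array(1.0)   # primal input
    v = jnp.array(1.0)   # tangent vector

    y, y_dot = jax.jvp(f, (x,), (v,))   # y = f(x); y_dot = J(x) @ v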
pure function
A pure function is a function whose outputs are based only on its inputs, and which has
no side-effects. JAX's :term:`transformation` model is designed to work with pure functions.
See also :term:`functional programming`.
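For instance, a sketch contrasting a pure function with an impure one whose side effect is not reliably preserved under :func:`jax.jit` (the names ``pure_add``, ``impure_add``, and ``counter`` are illustrative)::

    import jax

    def pure_add(x, y):
      return x + y            # output depends only on inputs; no side effects

    counter = []

    def impure_add(x, y):
      counter.append(x)       # side effect: runs during tracing, not on later cached calls
      return x + y

    jitted = jax.jit(impure_add)
    jitted(1.0, 2.0)
    jitted(3.0, 4.0)
    print(len(counter))       # 1, not 2: the side effect ran only once, while tracing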
reverse-mode autodiff
See :term:`VJP`.
static
In :term:`JIT` compilation, a value that is not traced (see :term:`Tracer`) and is instead
treated as a compile-time constant. The term also sometimes refers to compile-time
computations on static values.
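For example, :func:`jax.jit` can be told to treat an argument as static via ``static_argnums``; a sketch (``repeat`` is illustrative)::

    import jax
    import jax.numpy as jnp

    def repeat(x, n):
      return jnp.tile(x, n)     # the repeat count must be known at compile time

    repeat_jit = jax.jit(repeat, static_argnums=1)
    print(repeat_jit(jnp.arange(3), 2))   # [0 1 2 0 1 2]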
TPU
Short for *Tensor Processing Unit*, TPUs are chips specifically engineered for fast operations
on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for
fast operations on arrays (see also :term:`CPU` and :term:`GPU`).
Tracer
An object used as a stand-in for a JAX :term:`DeviceArray` in order to determine the
sequence of operations performed by a Python function. Internally, JAX implements this
via the :class:`jax.core.Tracer` class.
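Tracers can be observed, for example, by printing an argument inside a :func:`jax.jit`-decorated function (a sketch; ``f`` is illustrative)::

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      print(x)          # during tracing, x is a Tracer, not a concrete array
      return x + 1

    f(jnp.arange(3))    # prints something like Traced<ShapedArray(int32[3])>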
transformation
A higher-order function: that is, a function that takes a function as input and outputs
a transformed function. Examples in JAX include :func:`jax.jit`, :func:`jax.vmap`, and
:func:`jax.grad`.
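Because transformations take and return functions, they compose; for example (``f`` is illustrative)::

    import jax
    import jax.numpy as jnp

    def f(x):
      return jnp.sum(x ** 2)

    grad_f = jax.grad(f)                           # function computing the gradient of f
    batched_grad_f = jax.vmap(grad_f)              # gradient mapped over a batch of inputs
    fast_batched_grad_f = jax.jit(batched_grad_f)  # the composition, JIT-compiled

    print(fast_batched_grad_f(jnp.ones((4, 3))))   # gradient 2*x for each row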
VJP
Short for *Vector Jacobian Product*, also sometimes known as *reverse-mode* automatic
differentiation. For more details, see :ref:`vector-jacobian-product`. In JAX, VJP is
a :term:`transformation` that is implemented via :func:`jax.vjp`. See also :term:`JVP`.
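:func:`jax.vjp` returns the function's value together with a function that evaluates vector-Jacobian products; a small sketch (``f``, ``x``, and the cotangent are illustrative)::

    import jax
    import jax.numpy as jnp

    f = lambda x: jnp.sin(x) * x
    x = jnp.array(1.0)

    y, f_vjp = jax.vjp(f, x)            # y = f(x); f_vjp evaluates v @ J(x)
    (x_bar,) = f_vjp(jnp.array(1.0))    # cotangent pulled back to the input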
XLA
Short for *Accelerated Linear Algebra*, XLA is a domain-specific compiler for linear
algebra operations that is the primary backend for :term:`JIT`-compiled JAX code.
See https://www.tensorflow.org/xla/.

@@ -53,6 +53,7 @@ For an introduction to JAX, start at the
rank_promotion_warning
type_promotion
custom_vjp_update
glossary
.. toctree::
:maxdepth: 2

@@ -619,6 +619,8 @@
"id": "mtSRvouV6vvG"
},
"source": [
"(jacobian-vector-product)=\n",
"\n",
"### Jacobian-Vector products (JVPs, aka forward-mode autodiff)\n",
"\n",
"JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar `grad` function is built on reverse-mode, but to explain the difference in the two modes, and when each can be useful, we need a bit of math background.\n",
@@ -708,6 +710,8 @@
"id": "PhkvkZazdXu1"
},
"source": [
"(vector-jacobian-product)=\n",
"\n",
"### Vector-Jacobian products (VJPs, aka reverse-mode autodiff)\n",
"\n",
"Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.\n",

@@ -331,6 +331,8 @@ To implement `hessian`, we could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd
+++ {"id": "mtSRvouV6vvG"}
(jacobian-vector-product)=
### Jacobian-Vector products (JVPs, aka forward-mode autodiff)
JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar `grad` function is built on reverse-mode, but to explain the difference in the two modes, and when each can be useful, we need a bit of math background.
@@ -400,6 +402,8 @@ To do better for functions like this, we just need to use reverse-mode.
+++ {"id": "PhkvkZazdXu1"}
(vector-jacobian-product)=
### Vector-Jacobian products (VJPs, aka reverse-mode autodiff)
Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.