mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
Glossary of terms
=================

.. glossary::

    Array
      JAX's analog of :class:`numpy.ndarray`. See :class:`jax.Array`.

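      As a minimal sketch (assuming a recent JAX release is installed; the values are
      arbitrary), a :class:`jax.Array` is usually created through :mod:`jax.numpy` and can be
      converted to a :class:`numpy.ndarray` when needed:

      .. code-block:: python

         import jax
         import jax.numpy as jnp
         import numpy as np

         # Create a 2x3 JAX array; it is placed on the default device.
         x = jnp.arange(6.0).reshape(2, 3)
         print(isinstance(x, jax.Array))  # True

         # Bring the data into host memory as a NumPy array.
         x_np = np.asarray(x)
         print(type(x_np), x_np.shape)
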
    CPU
      Short for *Central Processing Unit*, CPUs are the standard computational architecture
      available in most computers. JAX can run computations on CPUs, but can often achieve
      much better performance on :term:`GPU` and :term:`TPU`.

    Device
      The generic name for the :term:`CPU`, :term:`GPU`, or :term:`TPU` that JAX uses to
      perform computations.

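      For illustration, a sketch assuming a recent JAX release: :func:`jax.devices` lists the
      devices JAX can use, and :func:`jax.device_put` places an array on a chosen one. On a
      CPU-only machine the list contains a single CPU device.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         # Enumerate the devices visible to JAX (CPU, GPU, or TPU).
         devices = jax.devices()
         print(devices)
         print(devices[0].platform)  # a platform string such as "cpu"

         # Explicitly place an array on the first device.
         x = jax.device_put(jnp.ones(3), devices[0])
         print(x.devices())  # the set of devices holding x
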
    forward-mode autodiff
      See :term:`JVP`.

    functional programming
      A programming paradigm in which programs are defined by applying and composing
      :term:`pure functions<pure function>`. JAX is designed for use with functional programs.

    GPU
      Short for *Graphics Processing Unit*, GPUs were originally specialized for operations
      related to rendering images on screen, but are now much more general-purpose. JAX is
      able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`).

    jaxpr
      Short for *JAX expression*, a jaxpr is an intermediate representation of a computation
      that is generated by JAX and forwarded to :term:`XLA` for compilation and execution.
      See :ref:`jax-internals-jaxpr` for more discussion and examples.

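      A small sketch (assuming a recent JAX release; the function ``f`` is illustrative):
      :func:`jax.make_jaxpr` traces a Python function and returns its jaxpr, which can be
      printed to inspect the recorded operations.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         def f(x):
             return jnp.sin(x) * 2.0

         # Trace f with a scalar float input and build its jaxpr.
         jaxpr = jax.make_jaxpr(f)(1.0)
         print(jaxpr)  # lists the primitives applied, here sin followed by mul
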
    JIT
      Short for *Just-In-Time* compilation. In JAX, JIT generally refers to the compilation of
      array operations to :term:`XLA`, most often accomplished using :func:`jax.jit`.

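      As a minimal sketch (assuming a recent JAX release; ``selu`` is an illustrative
      function, not part of JAX), wrapping a function with :func:`jax.jit` compiles it with
      :term:`XLA` the first time it is called with a given input shape and dtype:

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         def selu(x, alpha=1.67, lmbda=1.05):
             return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

         # The first call traces and compiles selu; later calls with the same
         # input shape and dtype reuse the compiled executable.
         selu_jit = jax.jit(selu)
         x = jnp.arange(-2.0, 3.0)
         print(selu_jit(x))
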
    JVP
      Short for *Jacobian-Vector Product*, also sometimes known as *forward-mode* automatic
      differentiation. For more details, see :ref:`jacobian-vector-product`. In JAX, JVP is
      a :term:`transformation` that is implemented via :func:`jax.jvp`. See also :term:`VJP`.

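      A sketch, assuming a recent JAX release: :func:`jax.jvp` takes a function, a tuple of
      primal inputs, and a matching tuple of tangent vectors, and returns the function's
      output together with the Jacobian-vector product. For ``sin`` at a point ``x``, the JVP
      with tangent ``1.0`` is ``cos(x)``.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         # Evaluate sin at 0.5 and push the tangent 1.0 forward through it.
         y, y_dot = jax.jvp(jnp.sin, (0.5,), (1.0,))
         print(y, y_dot)  # approximately sin(0.5) and cos(0.5)
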
    primitive
      A primitive is a fundamental unit of computation used in JAX programs. Most functions
      in :mod:`jax.lax` represent individual primitives. When representing a computation in
      a :term:`jaxpr`, each operation in the jaxpr is a primitive.

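      To see primitives directly, a sketch assuming a recent JAX release (the function ``f``
      is illustrative): :func:`jax.lax.add` and :func:`jax.lax.mul` each correspond to a
      single primitive, and each shows up as one equation in the traced jaxpr.

      .. code-block:: python

         import jax
         from jax import lax

         def f(x, y):
             return lax.mul(lax.add(x, y), y)

         jaxpr = jax.make_jaxpr(f)(1.0, 2.0)
         print(jaxpr)

         # Collect the name of the primitive behind each equation.
         names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
         print(names)  # names of the primitives, e.g. "add" and "mul"
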
    pure function
      A pure function is a function whose outputs are based only on its inputs, and which has
      no side effects. JAX's :term:`transformation` model is designed to work with pure functions.
      See also :term:`functional programming`.

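      The following sketch (assuming a recent JAX release; the function names are
      illustrative) shows why purity matters under :func:`jax.jit`: a Python side effect
      runs only while the function is being traced, not on every call, so an impure
      function can silently behave differently once transformed.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         def pure_fn(x):
             # The output depends only on x, and nothing outside is modified.
             return 2.0 * x + 1.0

         calls = []

         def impure_fn(x):
             # Side effect: mutates a global list whenever the Python body executes.
             calls.append("called")
             return 2.0 * x + 1.0

         jit_impure = jax.jit(impure_fn)
         jit_impure(jnp.float32(1.0))
         jit_impure(jnp.float32(2.0))  # same shape and dtype: reuses the compiled code
         print(len(calls))  # 1, because the side effect ran only during tracing
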
    pytree
      A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more
      general containers of array values in a uniform way. Refer to :ref:`working-with-pytrees`
      for a more detailed discussion.

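      A minimal sketch, assuming a recent JAX release (the container ``params`` is
      illustrative): the functions in :mod:`jax.tree_util` treat a nested dict, list, and
      tuple of arrays as a single pytree.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         # A nested container whose leaves are arrays.
         params = {"w": [jnp.ones(2), jnp.zeros(3)], "b": (jnp.array(1.0),)}

         # Apply a function to every leaf while keeping the container structure.
         doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)

         # Flatten the container into a flat list of its leaves.
         leaves = jax.tree_util.tree_leaves(params)
         print(len(leaves))  # 3
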
    reverse-mode autodiff
      See :term:`VJP`.

    SPMD
      Short for *Single Program, Multiple Data*, SPMD refers to a parallel computation technique
      in which the same computation (e.g., the forward pass of a neural net) is run on different
      input data (e.g., different inputs in a batch) in parallel on different devices (e.g.,
      several TPUs). :func:`jax.pmap` is a JAX :term:`transformation` that implements SPMD
      parallelism.

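      A sketch assuming a recent JAX release: :func:`jax.pmap` maps a function over the
      leading axis of its input, running one copy per local device. The leading axis size
      must not exceed the number of available devices, so this example sizes it with
      :func:`jax.local_device_count` and therefore also runs on a single-CPU machine.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         n = jax.local_device_count()

         # One row of input data per device.
         xs = jnp.arange(n * 4.0).reshape(n, 4)

         # The same computation (sum of squares) runs on every device, each on its own row.
         per_device = jax.pmap(lambda x: jnp.sum(x ** 2))(xs)
         print(per_device.shape)  # (n,)
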
    static
      In a :term:`JIT` compilation, a value that is not traced (see :term:`Tracer`). Also
      sometimes refers to compile-time computations on static values.

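      As a sketch (assuming a recent JAX release; ``power`` is an illustrative function),
      marking an argument static with the ``static_argnums`` parameter of :func:`jax.jit`
      means it is passed as an ordinary Python value rather than a :term:`Tracer`, so it can
      drive Python control flow. Each new value of a static argument triggers a fresh
      compilation.

      .. code-block:: python

         import functools

         import jax
         import jax.numpy as jnp

         @functools.partial(jax.jit, static_argnums=1)
         def power(x, n):
             # n is a concrete Python int at trace time, so range(n) works.
             result = jnp.ones_like(x)
             for _ in range(n):
                 result = result * x
             return result

         print(power(jnp.float32(2.0), 3))  # 8.0
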
    TPU
      Short for *Tensor Processing Unit*, TPUs are chips specifically engineered for fast operations
      on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for
      fast operations on arrays (see also :term:`CPU` and :term:`GPU`).

    Tracer
      An object used as a stand-in for a JAX :term:`Array` in order to determine the
      sequence of operations performed by a Python function. Internally, JAX implements this
      via the :class:`jax.core.Tracer` class.

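      A minimal sketch (assuming a recent JAX release; ``f`` is illustrative): inside a
      function being traced by :func:`jax.jit`, the argument is a tracer object rather
      than a concrete array. The exact tracer class name is an internal detail and may
      differ between JAX versions.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         seen = []

         @jax.jit
         def f(x):
             # Record the type of x as seen during tracing.
             seen.append(type(x).__name__)
             return x + 1

         out = f(jnp.arange(3))
         print(seen)       # a tracer class name, e.g. one ending in "Tracer"
         print(type(out))  # outside jit, the result is a concrete array again
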
    transformation
      A higher-order function: that is, a function that takes a function as input and outputs
      a transformed function. Examples in JAX include :func:`jax.jit`, :func:`jax.vmap`, and
      :func:`jax.grad`.

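      Because each transformation returns a new function, transformations compose. As a
      sketch (assuming a recent JAX release; ``loss`` is an illustrative function), the
      snippet below combines :func:`jax.grad`, :func:`jax.vmap`, and :func:`jax.jit` to
      compute a compiled, per-example gradient:

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         def loss(w, x):
             return (w * x) ** 2

         # grad: d(loss)/dw; vmap: map over x only (w is shared); jit: compile the result.
         per_example_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

         w = jnp.array(2.0)
         xs = jnp.arange(3.0)
         print(per_example_grad(w, xs))  # 2 * w * x**2 for each x: values 0, 4, 16
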
    VJP
      Short for *Vector-Jacobian Product*, also sometimes known as *reverse-mode* automatic
      differentiation. For more details, see :ref:`vector-jacobian-product`. In JAX, VJP is
      a :term:`transformation` that is implemented via :func:`jax.vjp`. See also :term:`JVP`.

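      A sketch, assuming a recent JAX release: :func:`jax.vjp` returns the function's output
      and a second function that maps a cotangent vector ``v`` to the product of ``v`` with
      the Jacobian. For elementwise ``sin`` the Jacobian is diagonal, so pulling back a vector
      of ones yields ``cos(x)``.

      .. code-block:: python

         import jax
         import jax.numpy as jnp

         x = jnp.array([0.0, 0.5, 1.0])

         # y = sin(x); f_vjp computes v @ J, where J is the Jacobian of sin at x.
         y, f_vjp = jax.vjp(jnp.sin, x)
         (x_bar,) = f_vjp(jnp.ones_like(y))
         print(x_bar)  # equal to cos(x)
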
    XLA
      Short for *Accelerated Linear Algebra*, XLA is a domain-specific compiler for linear
      algebra operations that is the primary backend for :term:`JIT`-compiled JAX code.
      See https://www.tensorflow.org/xla/.

    weak type
      A JAX data type that has the same type promotion semantics as Python scalars;
      see :ref:`weak-types`.

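      A minimal sketch, assuming a recent JAX release: multiplying a ``bfloat16`` array by a
      Python float keeps ``bfloat16``, because the Python scalar is weakly typed, whereas
      multiplying by an explicitly typed ``float32`` value promotes the result to ``float32``.

      .. code-block:: python

         import jax.numpy as jnp

         x = jnp.ones(3, dtype=jnp.bfloat16)

         # Weakly typed Python scalar: the array's dtype wins.
         print((x * 2.0).dtype)  # bfloat16

         # Strongly typed float32 scalar: standard promotion to float32.
         print((x * jnp.float32(2.0)).dtype)  # float32
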