Add/update JAX Advanced Tutorials docs, ToC structure

This commit is contained in:
8bitmp3 2024-09-20 22:19:14 +00:00
parent 6b93b35842
commit 0cf040c9a1
20 changed files with 20 additions and 4426 deletions

View File

@@ -38,10 +38,7 @@ JAX 201
:maxdepth: 1
parallelism
advanced-autodiff
gradient-checkpointing
advanced-debugging
external-callbacks
profiling-and-performance
JAX 301
@@ -50,6 +47,4 @@ JAX 301
.. toctree::
:maxdepth: 1
jax-primitives
jaxpr
advanced-compilation

View File

@@ -360,4 +360,6 @@ rediraffe_redirects = {
'jax-101/07-state.md': 'stateful-computations.md',
'jax-101/08-pjit.rst': 'sharded-computation.md',
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
}

View File

@@ -10,8 +10,6 @@ that use or interface with JAX.
:caption: Extensible JAX internals
:maxdepth: 1
notebooks/How_JAX_primitives_work
jaxpr
notebooks/Writing_custom_interpreters_in_Jax
Custom_Operation_for_GPUs
jax.extend

View File

@@ -364,7 +364,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
"We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
]
},
{

View File

@@ -311,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
```
We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
```{code-cell} ipython3
jax.make_jaxpr(jax.vmap(rms_norm))(x)

View File

@@ -30,7 +30,7 @@ Glossary of terms
jaxpr
Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
See :ref:`understanding-jaxprs` for more discussion and examples.
See :ref:`jax-internals-jaxpr` for more discussion and examples.
JIT
Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of

View File

@@ -121,8 +121,6 @@ maintains an up-to-date list.
installation
quickstart
notebooks/Common_Gotchas_in_JAX
faq
.. toctree::
:hidden:
@@ -130,11 +128,14 @@ maintains an up-to-date list.
tutorials
notebooks/Common_Gotchas_in_JAX
faq
.. toctree::
:hidden:
:maxdepth: 2
:caption: Resources
:caption: More guides/resources
user_guides
advanced_guide

View File

@@ -1,472 +0,0 @@
.. _understanding-jaxprs:
Understanding Jaxprs
====================
Updated: May 3, 2020 (for commit f1a46fe).
Conceptually, one can think of JAX transformations as first trace-specializing
the Python function to be transformed into a small and well-behaved
intermediate form that is then interpreted with transformation-specific
interpretation rules. One of the reasons JAX can pack so much power into such a
small software package is that it starts with a familiar and flexible
programming interface (Python with NumPy) and it uses the actual Python
interpreter to do most of the heavy lifting to distill the essence of the
computation into a simple statically-typed expression language with limited
higher-order features. That language is the jaxpr language.
Not all Python programs can be processed this way, but it turns out that many
scientific computing and machine learning programs can.
Before we proceed, it is important to point out that not all JAX
transformations literally materialize a jaxpr as described above; some, e.g.,
differentiation or batching, will apply transformations incrementally during
tracing. Nevertheless, if one wants to understand how JAX works internally, or
to make use of the result of JAX tracing, it is useful to understand jaxprs.
A jaxpr instance represents a function with one or more typed parameters (input
variables) and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes. The
inputs and outputs have types, which in JAX are represented as abstract values.
There are two related representations in the code for jaxprs,
:py:class:`jax.core.Jaxpr` and :py:class:`jax.core.ClosedJaxpr`. A
:py:class:`jax.core.ClosedJaxpr` represents a partially-applied
:py:class:`jax.core.Jaxpr`, and is what you obtain when you use
:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields:
* ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual
computation content of the function (described below).
* ``consts`` is a list of constants.
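For instance, here is a minimal sketch of accessing those two fields via
:py:func:`jax.make_jaxpr` (the lambda and input used here are purely illustrative):

>>> import jax
>>> import jax.numpy as jnp
>>> closed_jaxpr = jax.make_jaxpr(lambda x: x + 1.0)(jnp.zeros(3))
>>> jaxpr = closed_jaxpr.jaxpr    # the jax.core.Jaxpr with the computation content
>>> consts = closed_jaxpr.consts  # the list of constants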
The most interesting part of the ClosedJaxpr is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr` as printed using the following
grammar::
Jaxpr ::= { lambda Var* ; Var+. let
Eqn*
in [Expr+] }
where:
* The parameters of the jaxpr are shown as two lists of variables separated by
``;``. The first set of variables are the ones that have been introduced
to stand for constants that have been hoisted out. These are called the
``constvars``, and in a :py:class:`jax.core.ClosedJaxpr` the ``consts``
field holds the corresponding values. The second list of variables, called
``invars``, corresponds to the inputs of the traced Python function.
* ``Eqn*`` is a list of equations defining intermediate variables that refer to
intermediate expressions. Each equation defines one or more variables as the
result of applying a primitive to some atomic expressions. Each equation uses only
input variables and intermediate variables defined by previous equations.
* ``Expr+`` is a list of output atomic expressions (literals or variables)
for the jaxpr.
Equations are printed as follows::
Eqn ::= Var+ = Primitive [ Param* ] Expr+
where:
* ``Var+`` are one or more intermediate variables to be defined as the output
of a primitive invocation (some primitives can return multiple values).
* ``Expr+`` are one or more atomic expressions, each either a variable or a
literal constant. A special variable ``unitvar`` or literal ``unit``,
printed as ``*``, represents a value that is not needed
in the rest of the computation and has been elided. That is, units are just
placeholders.
* ``Param*`` are zero or more named parameters to the primitive, printed in
square brackets. Each parameter is shown as ``Name = Value``.
Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments)::
Primitive := add | sub | sin | mul | ...
The jaxpr primitives are documented in the :py:mod:`jax.lax` module.
For example, here is the jaxpr produced for the function ``func1`` below
>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
... temp = first + jnp.sin(second) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
Here there are no constvars; ``a`` and ``b`` are the input variables,
corresponding respectively to the function parameters ``first`` and ``second``.
The scalar literal ``3.0`` is kept inline.
The ``reduce_sum`` primitive has named parameter ``axes``, in addition to the
operand ``e``.
Note that even though execution of a program that calls into JAX builds a jaxpr,
Python-level control-flow and Python-level functions execute normally.
This means that even if a Python program contains functions and control flow,
the resulting jaxpr does not have to contain control flow or higher-order features.
For example, when tracing the function ``func3``, JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
jaxpr as before
>>> def func2(inner, first, second):
... temp = first + inner(second) * 3.
... return jnp.sum(temp)
...
>>> def inner(second):
... if second.shape[0] > 4:
... return jnp.sin(second)
... else:
... assert False
...
>>> def func3(first, second):
... return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
Handling PyTrees
----------------
In jaxpr there are no tuple types; instead primitives take multiple inputs
and produce multiple outputs. When processing a function that has structured
inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists
of inputs and outputs. For more details, please see the documentation for
PyTrees (:ref:`pytrees`).
For example, the following code produces an identical jaxpr to what we saw
before (with two input vars, one for each element of the input tuple)
>>> def func4(arg): # Arg is a pair
... temp = arg[0] + jnp.sin(arg[1]) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
Constant vars
-------------
Some values in jaxprs are constants, in that their value does not depend on the
jaxpr's arguments. When these values are scalars they are represented directly
in the jaxpr equations; non-scalar array constants are instead hoisted out to
the top-level jaxpr, where they correspond to constant variables ("constvars").
These constvars differ from the other jaxpr parameters ("invars") only as a
bookkeeping convention.
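As a rough sketch (the exact hoisting behavior can vary across JAX versions), a
non-scalar array captured from an enclosing scope ends up in the ``consts`` field
of the resulting :py:class:`jax.core.ClosedJaxpr`:

>>> import numpy as np
>>> captured = np.arange(3.0)  # a non-scalar array constant
>>> closed = make_jaxpr(lambda x: x + captured)(jnp.zeros(3))
>>> consts = closed.consts     # the hoisted constant values live here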
Higher-order primitives
-----------------------
The jaxpr language includes several higher-order primitives. They are more complicated because
they include sub-jaxprs.
Conditionals
^^^^^^^^^^^^
JAX traces through normal Python conditionals. To capture a
conditional expression for dynamic execution, one must use the
:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors,
which have the signatures::
lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B
lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B
Both of these will bind a primitive called ``cond`` internally. The
``cond`` primitive in jaxprs reflects the more general signature of
:py:func:`lax.switch`: it takes an integer denoting the index of the branch
to execute (clamped into valid indexing range).
For example:
>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
... return lax.switch(index, [lambda x: x + 1.,
... lambda x: x - 2.,
... lambda x: x + 3.],
... arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
d:i32[] = clamp 0 c 2
e:f32[] = cond[
branches=(
{ lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
{ lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
{ lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
)
] d b
in (e,) }
The `branches` parameter to the cond primitive corresponds to the branch
functionals. In this example, those functionals each take one input variable,
corresponding to ``x``.
The above instance of the cond primitive takes two operands. The first
one (``d``) is the branch index, then ``b`` is the operand (``arg``) to
be passed to whichever jaxpr in ``branches`` is selected by the branch
index.
Another example, using :py:func:`lax.cond`:
>>> from jax import lax
>>>
>>> def func7(arg):
... return lax.cond(arg >= 0.,
... lambda xtrue: xtrue + 3.,
... lambda xfalse: xfalse - 3.,
... arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
b:bool[] = ge a 0.0
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
d:f32[] = cond[
branches=(
{ lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
{ lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
)
] c a
in (d,) }
In this case, the boolean predicate is converted to an integer index
(0 or 1), and ``branches`` are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
one input variable, corresponding to ``xfalse`` and ``xtrue``
respectively.
The following example shows a more complicated situation in which the input
to the branch functionals is a tuple, and the `false` branch functional
contains a constant ``jnp.array([1])`` that is hoisted as a `constvar`
>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... lambda xtrue: xtrue[0],
... lambda xfalse: jnp.array([1]) + xfalse[1],
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
e:bool[] = ge b 0.0
f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
g:f32[1] = cond[
branches=(
{ lambda ; h:i32[1] i:f32[1] j:f32[]. let
k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
l:f32[1] = add k j
in (l,) }
{ lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) }
)
] f a c d
in (g,) }
While
^^^^^
Just like for conditionals, Python loops are inlined during tracing.
If you want to capture a loop for dynamic execution, you must use one of several
special operations, :py:func:`jax.lax.while_loop` (a primitive)
and :py:func:`jax.lax.fori_loop`
(a helper that generates a while_loop primitive)::
lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
In the above signature, “C” stands for the type of the loop “carry” value.
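As a quick sketch of the ``while_loop`` signature above (a toy example; the carry
here is just a single counter):

>>> from jax import lax
>>> steps = lax.while_loop(lambda c: c < 5,  # cond_fun: continue while this holds
...                        lambda c: c + 1,  # body_fun: update the carry
...                        0)                # init: the initial carry value
>>> assert int(steps) == 5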
For example, here is a fori loop
>>> import numpy as np
>>>
>>> def func10(arg, n):
... ones = jnp.ones(arg.shape) # A constant
... return lax.fori_loop(0, n,
... lambda i, carry: carry + ones * 3. + arg,
... arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[16] = add a c
_:i32[] _:i32[] e:f32[16] = while[
body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
k:i32[] = add h 1
l:f32[16] = mul f 3.0
m:f32[16] = add j l
n:f32[16] = add m g
in (k, i, n) }
body_nconsts=2
cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
r:bool[] = lt o p
in (r,) }
cond_nconsts=0
] c a 0 b d
in (e,) }
The while primitive takes 5 arguments: ``c a 0 b d``, as follows:
* 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0)
* 2 constants for ``body_jaxpr`` (``c`` and ``a``)
* 3 parameters for the initial value of the carry
Scan
^^^^
JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there are a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the :py:func:`jax.lax.scan` function::
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
This is written in terms of a `Haskell Type Signature`_:
``C`` is the type of the scan carry, ``A`` is the element type of the
input array(s), and ``B`` is the element type of the output array(s).
For example, consider the function ``func11`` below
>>> def func11(arr, extra):
... ones = jnp.ones(arr.shape) # A constant
... def body(carry, aelems):
... # carry: running dot-product of the two arrays
... # aelems: a pair with corresponding elements from the two arrays
... ae1, ae2 = aelems
... return (carry + ae1 * ae2 + extra, carry)
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
d:f32[] e:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
j:f32[] = mul h i
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
l:f32[] = add k j
m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
n:f32[] = add l m
in (n, g) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False
unroll=1
] b 0.0 a c
in (d, e) }
The ``linear`` parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.
The scan primitive takes 4 arguments: ``b 0.0 a c``, of which:
* one is the free variable for the body
* one is the initial value of the carry
* the remaining two are the arrays over which the scan operates.
XLA_call
^^^^^^^^
The call primitive arises from JIT compilation (it appears as ``pjit`` in the
jaxpr printed below), and it encapsulates a sub-jaxpr along with parameters
that specify how the computation should be compiled and run. For example
>>> from jax import jit
>>>
>>> def func12(arg):
... @jit
... def inner(x):
... return x + arg * jnp.ones(1) # Include a constant in the inner function
... return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.)) # doctest:+ELLIPSIS
{ lambda ; a:f32[]. let
b:f32[] = sub a 2.0
c:f32[1] = pjit[
name=inner
jaxpr={ lambda ; d:f32[] e:f32[]. let
f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
h:f32[1] = mul g f
i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
j:f32[1] = add i h
in (j,) }
] a b
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
l:f32[1] = add k c
in (l,) }
XLA_pmap
^^^^^^^^
If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
captured using the ``xla_pmap`` primitive. Consider this example
>>> from jax import pmap
>>>
>>> def func13(arr, extra):
... def inner(x):
... # use a free variable "extra" and a constant jnp.ones(1)
... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a:f32[1,3] b:f32[]. let
c:f32[1,3] = xla_pmap[
axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[3] = add e f
h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
i:f32[3] = add g h
j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
k:f32[3] = div i j
in (k,) }
devices=None
donated_invars=(False, False)
global_axis_size=1
in_axes=(None, 0)
is_explicit_global_axis_size=False
name=inner
out_axes=(0,)
] b a
in (c,) }
The ``xla_pmap`` primitive specifies the name of the axis (parameter
``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``
parameter. The value of this parameter is a Jaxpr with 2 input variables.
The parameter ``in_axes`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast and the value of ``arr`` is mapped.
.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature

View File

@@ -51,7 +51,7 @@ def log2(x):
print(jax.make_jaxpr(log2)(3.0))
```
The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output.
The {ref}`jax-internals-jaxpr` section of the documentation provides more information on the meaning of the above output.
Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`.
This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.

File diff suppressed because it is too large.

View File

@@ -1,771 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "vfxqky4PCUnh"}
# How JAX primitives work
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
*necula@google.com*, October 2019.
JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,
`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,
which means that as the Python function executes
the only operations it applies to the data are either inspections of data
attributes such as shape or type, or special operations called JAX primitives.
In particular, a JAX-traceable function is sometimes invoked by JAX with
abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,
which captures the type and the shape of values, but not the concrete data values.
JAX primitives know how to operate on both concrete data
values and on the JAX abstract values.
The JAX-transformed functions must themselves be JAX-traceable functions,
to ensure that these transformations
can be composed, e.g., `jit(jacfwd(grad(f)))`.
There are pre-defined JAX primitives corresponding to most XLA operations,
e.g., add, matmul, sin, cos, indexing.
JAX comes with an implementation of NumPy functions in terms of JAX primitives, which means that Python programs
using JAX's implementation of NumPy are JAX-traceable and therefore transformable.
Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,
one can define a new primitive that encapsulates the behavior of the function.
**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically
as "multiply_add(x, y, z) = x * y + z".
This function operates on 3 identically-shaped tensors of floating point
values and performs the operations pointwise.
+++ {"id": "HIJYIHNTD1yI"}
## Using existing primitives
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other
functions that are themselves written using JAX primitives, e.g., those
defined in the `jax.lax` module:
```{code-cell} ipython3
:id: tbOF0LB0EMne
:outputId: 3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528
from jax import lax
from jax._src import api
def multiply_add_lax(x, y, z):
"""Implementation of multiply-add using the jax.lax primitives."""
return lax.add(lax.mul(x, y), z)
def square_add_lax(a, b):
"""A square-add function using the newly defined multiply-add."""
return multiply_add_lax(a, a, b)
print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
```
+++ {"id": "Cgv60Wm3E_D5"}
In order to understand how JAX is internally using the primitives,
we add some helpers for tracing function calls.
```{code-cell} ipython3
:cellView: form
:id: mQRQGEGiE53K
#@title Helper functions (execute this cell)
import functools
import traceback
_indentation = 0
def _trace(msg=None):
"""Print a message at current indentation."""
if msg is not None:
print(" " * _indentation + msg)
def _trace_indent(msg=None):
"""Print a message and then indent the rest."""
global _indentation
_trace(msg)
_indentation = 1 + _indentation
def _trace_unindent(msg=None):
"""Unindent then print a message."""
global _indentation
_indentation = _indentation - 1
_trace(msg)
def trace(name):
"""A decorator for functions to trace arguments and results."""
def trace_func(func): # pylint: disable=missing-docstring
def pp(v):
"""Print certain values more succinctly"""
vtype = str(type(v))
if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
return "<JaxComputationBuilder>"
elif "jaxlib.xla_extension.XlaOp" in vtype:
return "<XlaOp at 0x{:x}>".format(id(v))
elif ("partial_eval.JaxprTracer" in vtype or
"batching.BatchTracer" in vtype or
"ad.JVPTracer" in vtype):
return "Traced<{}>".format(v.aval)
elif isinstance(v, tuple):
return "({})".format(pp_values(v))
else:
return str(v)
def pp_values(args):
return ", ".join([pp(arg) for arg in args])
@functools.wraps(func)
def func_wrapper(*args):
_trace_indent("call {}({})".format(name, pp_values(args)))
res = func(*args)
_trace_unindent("|<- {} = {}".format(name, pp(res)))
return res
return func_wrapper
return trace_func
class expectNotImplementedError(object):
"""Context manager to check for NotImplementedError."""
def __enter__(self): pass
def __exit__(self, type, value, tb):
global _indentation
_indentation = 0
if type is NotImplementedError:
print("\nFound expected exception:")
traceback.print_exc(limit=3)
return True
elif type is None: # No exception
assert False, "Expected NotImplementedError"
else:
return False
```
+++ {"id": "Qf4eLrLCFYDl"}
Instead of using `jax.lax` primitives directly, we can use other functions
that are already written in terms of those primitives, such as those in `jax.numpy`:
```{code-cell} ipython3
:id: QhKorz6cFRJb
:outputId: aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a
import jax.numpy as jnp
import numpy as np
@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
return jnp.add(jnp.multiply(x, y), z)
@trace("square_add_numpy")
def square_add_numpy(a, b):
return multiply_add_numpy(a, a, b)
print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
```
+++ {"id": "Sg-D8EdeFn4a"}
Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and
`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further
below in this colab).
It is important to remember that a JAX-traceable function must be able to
operate not only on concrete arguments but also on special abstract arguments
that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written
in terms of JAX primitives.
+++ {"id": "WxrQO7-XGLcg"}
## Defining new JAX primitives
The right way to add support for multiply-add is in terms of existing
JAX primitives, as shown above. However, in order to demonstrate how JAX
primitives work let us pretend that we want to add a new primitive to
JAX for the multiply-add functionality.
```{code-cell} ipython3
:id: cPqAH1XOGTN4
from jax import core
multiply_add_p = core.Primitive("multiply_add") # Create the primitive
@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
"""The JAX-traceable way to use the JAX primitive.
Note that the traced arguments must be passed as positional arguments
to `bind`.
"""
return multiply_add_p.bind(x, y, z)
@trace("square_add_prim")
def square_add_prim(a, b):
"""A square-add function implemented using the new JAX-primitive."""
return multiply_add_prim(a, a, b)
```
+++ {"id": "LMzs5PAKGr-4"}
If we try to call the newly defined functions we get an error, because
we have not yet told JAX anything about the semantics of the new primitive.
```{code-cell} ipython3
:id: _X3PAYxhGpWd
:outputId: 90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8
with expectNotImplementedError():
square_add_prim(2., 10.)
```
+++ {"id": "elha0FdgHSEF"}
### Primal evaluation rules
```{code-cell} ipython3
:id: FT34FFAGHARU
:outputId: 4c54f1c2-8a50-4788-90e1-06aee412c43b
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
"""Concrete implementation of the primitive.
This function does not need to be JAX traceable.
Args:
x, y, z: the concrete arguments of the primitive. Will only be called with
concrete values.
Returns:
the concrete result of the primitive.
"""
# Note that we can use the original numpy, which is not JAX traceable
return np.add(np.multiply(x, y), z)
# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
```
```{code-cell} ipython3
:id: G5bstKaeNAVV
:outputId: deb94d5b-dfea-4e6f-9ec2-70b416c996c5
assert square_add_prim(2., 10.) == 14.
```
+++ {"id": "upBf-uAuHhPJ"}
### JIT
If we now try to use `jit` we get a `NotImplementedError`:
```{code-cell} ipython3
:id: QG-LULjiHk4b
:outputId: d4ef4406-8dae-4c96-97ca-b662340474ee
with expectNotImplementedError():
api.jit(square_add_prim)(2., 10.)
```
+++ {"id": "rHS1bAGHH44E"}
#### Abstract evaluation rules
In order to JIT the function, and for other transformations as well,
JAX first evaluates it abstractly using only the
shape and type of the arguments. This abstract evaluation serves multiple
purposes:
* Gets the sequence of JAX primitives that are used in the computation. This
sequence will be compiled.
* Computes the shape and type of all vectors and operations used in the computation.
For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
```{code-cell} ipython3
:id: ctQmEeckIbdo
:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2
from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
"""Abstract evaluation of the primitive.
This function does not need to be JAX traceable. It will be invoked with
abstractions of the actual arguments.
Args:
xs, ys, zs: abstractions of the arguments.
Result:
a ShapedArray for the result of the primitive.
"""
assert xs.shape == ys.shape
assert xs.shape == zs.shape
return core.ShapedArray(xs.shape, xs.dtype)
# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
```
+++ {"id": "RPN88X6YI43A"}
If we re-attempt to JIT, we see how the abstract evaluation proceeds, but
we get another error, about missing the actual XLA compilation rule:
```{code-cell} ipython3
:id: eOcNR92SI2h-
:outputId: 356ef229-3703-4696-cc3d-7c05de405fb0
with expectNotImplementedError():
api.jit(square_add_prim)(2., 10.)
```
+++ {"id": "9IOV1R-fJMHp"}
#### XLA Compilation rules
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is the biggest hurdle to adding new functionality to JAX, because the
set of XLA operations is limited, and JAX already has pre-defined primitives
for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.
```{code-cell} ipython3
:id: FYQWSSjKJaWP
from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
"""The compilation to XLA of the primitive.
Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
the results of the function.
Does not need to be a JAX-traceable function.
"""
return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
```
+++ {"id": "K98LX-VaJkFu"}
Now we can JIT successfully. Notice below that JAX first evaluates the function
abstractly, which triggers the `multiply_add_abstract_eval` function, and
then compiles the set of primitives it has encountered, including `multiply_add`.
At this point JAX invokes `multiply_add_lowering`.
```{code-cell} ipython3
:id: rj3TLsolJgEc
:outputId: e384bee4-1e9c-4344-f49c-d3b5ec08eb32
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
```
+++ {"id": "Omrez-2_KFfo"}
Below is another use of `jit`, where we compile only
with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads
to the third argument to `multiply_add_abstract_eval` being a
`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with
both `ShapedArray` and `ConcreteArray`.
```{code-cell} ipython3
:id: mPfTwIBoKOEK
:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b
assert api.jit(lambda x, y: square_add_prim(x, y),
static_argnums=1)(2., 10.) == 14.
```
+++ {"id": "_Ya3B5l4J1VA"}
### Forward differentiation
JAX implements forward differentiation in the form of
a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).
If we attempt now to compute the `jvp` function we get an
error because we have not yet told JAX how to differentiate
the `multiply_add` primitive.
```{code-cell} ipython3
:id: OxDx6NQnKwMI
:outputId: ce659ef3-c03c-4856-f252-49ec4b6eb964
# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
api.jvp(square_add_prim, (2., 10.), (1., 1.))
```
```{code-cell} ipython3
:id: zxG24C1JMIMM
from jax.interpreters import ad
@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
"""Evaluates the primal output and the tangents (Jacobian-vector product).
Given values of the arguments and perturbation of the arguments (tangents),
compute the output of the primitive and the perturbation of the output.
This method must be JAX-traceable. JAX may invoke it with abstract values
for the arguments and tangents.
Args:
arg_values: a tuple of arguments
arg_tangents: a tuple with the tangents of the arguments. The tuple has
the same length as the arg_values. Some of the tangents may also be the
special value ad.Zero to specify a zero tangent.
Returns:
a pair of the primal output and the tangent.
"""
x, y, z = arg_values
xt, yt, zt = arg_tangents
_trace("Primal evaluation:")
# Now we have a JAX-traceable computation of the output.
# Normally, we can use the ma primitive itself to compute the primal output.
primal_out = multiply_add_prim(x, y, z)
_trace("Tangent evaluation:")
# We must use a JAX-traceable way to compute the tangent. It turns out that
# the output tangent can be computed as (xt * y + x * yt + zt),
# which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.
# We do need to deal specially with Zero. Here we just turn it into a
# proper tensor of 0s (of the same shape as 'x').
# An alternative would be to check for Zero and perform algebraic
# simplification of the output tangent computation.
def make_zero(tan):
return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
return (primal_out, output_tangent)
# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
```
```{code-cell} ipython3
:id: ma3KBkiAMfW1
:outputId: f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
```
+++ {"id": "69QsEcu-lP4u"}
TO EXPLAIN:
* Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
* Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet
we do not call the multiply_add_abstract_eval.
* I think it would be useful to show the jaxpr here
+++ {"id": "Sb6e3ZAHOPHv"}
#### JIT of forward differentiation
We can apply JIT to the forward differentiation function:
```{code-cell} ipython3
:id: hg-hzVu-N-hv
:outputId: 38d32067-e152-4046-ad80-7f95a31ba628
assert api.jit(lambda arg_values, arg_tangents:
api.jvp(square_add_prim, arg_values, arg_tangents))(
(2., 10.), (1., 1.)) == (14., 5.)
```
+++ {"id": "jlZt1_v2mU88"}
Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn
evaluates abstractly both the primal and the tangent evaluation (a total of
3 invocations of the `ma` primitive). Then we compile the 3 occurrences
of the primitive.
+++ {"id": "555yt6ZIOePB"}
### Reverse differentiation
If we attempt now to use reverse differentiation we
see that JAX starts by using the `multiply_add_value_and_jvp` to
compute the forward differentiation for abstract values, but then runs
into a `NotImplementedError`.
When computing the reverse differentiation JAX first does abstract evaluation
of the forward differentiation code `multiply_add_value_and_jvp` to obtain a
trace of primitives that compute the output tangent.
Observe that JAX performs this abstract evaluation with concrete values
for the differentiation point, and abstract values for the tangents.
Observe also that JAX uses the special abstract tangent value `Zero` for
the tangent corresponding to the 3rd argument of `ma`. This reflects the
fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,
which flows to the 3rd argument to `multiply_add_prim`.
Observe also that during the abstract evaluation of the tangent we pass the
value 0.0 as the tangent for the 3rd argument. This is due to the use
of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.
```{code-cell} ipython3
:id: 8eAVnexaOjBn
:outputId: e4ee89cf-ab4a-4505-9817-fa978a2865ab
# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
api.grad(square_add_prim)(2., 10.)
```
+++ {"id": "fSHLUMDN26AY"}
The above error is because there is a missing piece for JAX to be able
to use the forward differentiation code to compute reverse differentiation.
+++ {"id": "3ibDbGF-PjK9"}
#### Transposition
As explained above, when computing reverse differentiation JAX obtains
a trace of primitives that compute the tangent using forward differentiation.
Then, **JAX interprets this trace abstractly backwards** and for each
primitive it applies a **transposition** rule.
To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:
```
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
```
By construction, the tangent calculation is always linear in the input tangents.
The only non-linear operator that may arise in the tangent calculation is multiplication,
but then one of the operands is constant.
JAX will produce the reverse differentiation computation by processing the
JVP computation backwards. For each operation in the tangent computation,
it accumulates the cotangents
of the variables used by the operation, using the cotangent of the result
of the operation:
```
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
```
One can verify that this computation produces `xct = 4.` and `yct = 3.`, which
are the partial derivatives of the function `f`.
JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:
```
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
```
Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other
arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned
for the constant arguments.
In particular,
```
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
```
```{code-cell} ipython3
:id: JaHxFdkRO42r
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
"""Evaluates the transpose of a linear primitive.
This method is only used when computing the backward gradient following
value_and_jvp, and is only needed for primitives that are used in the JVP
calculation for some other primitive. We need transposition for multiply_add_prim,
because we have used multiply_add_prim in the computation of the output_tangent in
multiply_add_value_and_jvp.
In our case, multiply_add is not a linear primitive. However, it is used linearly
w.r.t. tangents in multiply_add_value_and_jvp:
output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
One of the first two multiplicative arguments is always a constant.
Args:
ct: the cotangent of the output of the primitive.
x, y, z: values of the arguments. The arguments that are used linearly
get an ad.UndefinedPrimal value. The other arguments get a constant
value.
Returns:
a tuple with the cotangent of the inputs, with the value None
corresponding to the constant arguments.
"""
if not ad.is_undefined_primal(x):
# This use of multiply_add is with a constant "x"
assert ad.is_undefined_primal(y)
ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
res = None, ct_y, ct
else:
# This use of multiply_add is with a constant "y"
assert ad.is_undefined_primal(x)
ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
res = ct_x, None, ct
return res
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
```
+++ {"id": "PpChox-Jp7wb"}
Now we can complete the run of the `grad`:
```{code-cell} ipython3
:id: PogPKS4MPevd
:outputId: d33328d4-3e87-45b5-9b31-21ad624b67af
assert api.grad(square_add_prim)(2., 10.) == 4.
```
+++ {"id": "8M1xLCXW4fK7"}
Notice the two calls to `multiply_add_transpose`. They correspond to the two
uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the
last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0.
+++ {"id": "EIJs6FYmPg6c"}
#### JIT of reverse differentiation
Notice that the abstract evaluation of `multiply_add_value_and_jvp` now uses only
abstract values, while in the absence of JIT we used `ConcreteArray`.
```{code-cell} ipython3
:id: FZ-JGbWZPq2-
:outputId: e42b5222-9c3e-4853-e13a-874f6605d178
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
```
+++ {"id": "-3lqPkdQPvl5"}
### Batching
The batching transformation takes a point-wise computation and turns it
into a computation on vectors. If we try it right now, we get a `NotImplementedError`:
```{code-cell} ipython3
:id: hFvBR3I9Pzh3
:outputId: 434608bc-281f-4d3b-83bd-eaaf3b51b1cd
# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
np.array([10., 20.]))
```
+++ {"id": "gILasMiP6elR"}
We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, `multiply_add_prim` already operates pointwise on inputs of any dimension, so the batched version can use the same `multiply_add_prim` implementation.
```{code-cell} ipython3
:id: KQfeqRIrP7zg
from jax.interpreters import batching
@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
"""Computes the batched version of the primitive.
This must be a JAX-traceable function.
Since the multiply_add primitive already operates pointwise on tensors of
arbitrary dimension, to batch it we can use the primitive itself. This works as
long as all the inputs have the same dimensions and are batched along the
same axes. The result is batched along the same axis as the inputs.
Args:
vector_arg_values: a tuple of three arguments, each being a tensor of matching
shape.
batch_axes: the axes that are being batched. See vmap documentation.
Returns:
a tuple of the result, and the result axis that was batched.
"""
assert batch_axes[0] == batch_axes[1]
assert batch_axes[0] == batch_axes[2]
_trace("Using multiply_add to compute the batch:")
res = multiply_add_prim(*vector_arg_values)
return res, batch_axes[0]
batching.primitive_batchers[multiply_add_p] = multiply_add_batch
```
```{code-cell} ipython3
:id: VwxNk869P_YG
:outputId: 9d22c921-5803-4d33-9e88-b6e439ba9738
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
np.array([2., 3.]),
np.array([10., 20.])),
[14., 29.])
```
+++ {"id": "NmqLlV1TQDCC"}
#### JIT of batching
```{code-cell} ipython3
:id: xqEdXVUgQCTt
:outputId: 9c22fd9c-919c-491d-bbeb-32c241b808fa
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
(np.array([2., 3.]),
np.array([10., 20.])),
[14., 29.])
```

File diff suppressed because it is too large.

View File

@@ -1,515 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "7XNMxdTwURqI"}
# External callbacks
<!--* freshness: { reviewed: '2024-04-08' } *-->
+++ {"id": "h6lXo6bSUYGq"}
This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation.
+++ {"id": "Xi_nhfpnlmbm"}
## Why callbacks?
A callback routine is a way to perform **host-side** execution of code at runtime.
As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.
Using a simple Python `print` statement, it looks like this:
```{code-cell}
:id: lz8rEL1Amb4r
:outputId: bbd37102-19f2-46d2-b794-3d4952c6fe97
import jax
@jax.jit
def f(x):
y = x + 1
print("intermediate value: {}".format(y))
return y * 2
result = f(2)
```
+++ {"id": "yEy41sFAmxOp"}
What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).
To print the value at runtime we need a callback, for example `jax.debug.print`:
```{code-cell}
:id: wFfHmoQxnKDF
:outputId: 6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a
@jax.jit
def f(x):
y = x + 1
jax.debug.print("intermediate value: {}", y)
return y * 2
result = f(2)
```
+++ {"id": "CvWv3pudn9X5"}
This works by passing the runtime value represented by `y` back to the host process, where the host can print the value.
+++ {"id": "X0vR078znuT-"}
## Flavors of Callback
In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:
- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect.
- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.
- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.
(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`).
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|-------------------------------------|----|----|----|----|----|----|
|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |
|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |
|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |
¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff
² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.
³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases.
+++ {"id": "hE_M8DaPvoym"}
### Exploring `jax.pure_callback`
`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).
The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times.
```{code-cell}
:id: 4lQDzXy6t_-k
:outputId: 279e4daf-0540-4eab-f535-d3bcbac74c44
import jax
import jax.numpy as jnp
import numpy as np
def f_host(x):
# call a numpy (not jax.numpy) operation:
return np.sin(x).astype(x.dtype)
def f(x):
result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jax.pure_callback(f_host, result_shape, x)
x = jnp.arange(5.0)
f(x)
```
+++ {"id": "q7YCIr8qMrDs"}
Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:
```{code-cell}
:id: bgoZ0fxsuoWV
:outputId: 901443bd-5cb4-4923-ce53-6f832ac22ca9
jax.jit(f)(x)
```
```{code-cell}
:id: ajBRGWGfupu2
:outputId: b28e31ee-7457-4b92-872b-52d819f53ddf
jax.vmap(f)(x)
```
```{code-cell}
:id: xe7AOGexvC13
:outputId: 8fa77977-1f2b-41c5-cc5e-11993ee5aa3e
def body_fun(_, x):
return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
```
+++ {"id": "tMzAVs2VNj5G"}
However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:
```{code-cell}
:id: 4QAF4VhUu5bb
:outputId: f8a06d02-47e9-4240-8077-d7be81e5a480
%xmode minimal
```
```{code-cell}
:id: qUpKPxlOurfY
:outputId: 11a665e8-40eb-4b0e-dc2e-a544a25fc57e
:tags: [raises-exception]
jax.grad(f)(x)
```
+++ {"id": "y9DAibV4Nwpo"}
For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below.
+++ {"id": "LrvdAloMZbIe"}
By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:
```{code-cell}
:id: mmFc_zawZrBq
:outputId: a4df7568-3f64-4b2f-9a2c-7adb2e0815e0
def print_something():
print('printing something')
return np.int32(0)
@jax.jit
def f1():
return jax.pure_callback(print_something, np.int32(0))
f1();
```
```{code-cell}
:id: tTwE4kpmaNei
@jax.jit
def f2():
jax.pure_callback(print_something, np.int32(0))
return 1.0
f2();
```
+++ {"id": "qfyGYbw4Z5U3"}
In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.
In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.
+++ {"id": "JHcJybr7OEBM"}
### Exploring `jax.experimental.io_callback`
In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.
As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!).
```{code-cell}
:id: eAg5xIhrOiWV
:outputId: e3cfec21-d843-4852-a49d-69a69fba9fc1
from jax.experimental import io_callback
from functools import partial
global_rng = np.random.default_rng(0)
def host_side_random_like(x):
"""Generate a random array like x using the global_rng state"""
# We have two side-effects here:
# - printing the shape and dtype
# - calling global_rng, thus updating its state
print(f'generating {x.dtype}{list(x.shape)}')
return global_rng.uniform(size=x.shape).astype(x.dtype)
@jax.jit
def numpy_random_like(x):
return io_callback(host_side_random_like, x, x)
x = jnp.zeros(5)
numpy_random_like(x)
```
+++ {"id": "mAIF31MlXj33"}
The `io_callback` is compatible with `vmap` by default:
```{code-cell}
:id: NY3o5dG6Vg6u
:outputId: a67a8a98-214e-40ca-ad98-a930cd3db85e
jax.vmap(numpy_random_like)(x)
```
+++ {"id": "XXvSeeOXXquZ"}
Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.
If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:
```{code-cell}
:id: 3aNmRsDrX3-2
:outputId: a8ff4b77-f4cb-442f-8cfb-ea7251c66274
:tags: [raises-exception]
@jax.jit
def numpy_random_like_ordered(x):
return io_callback(host_side_random_like, x, x, ordered=True)
jax.vmap(numpy_random_like_ordered)(x)
```
+++ {"id": "fD2FTHlUYAZH"}
On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:
```{code-cell}
:id: lMVzZlIEWL7F
:outputId: f9741c18-a30d-4d46-b706-8102849286b5
def body_fun(_, x):
return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
```
+++ {"id": "w_sf8mCbbo8K"}
Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:
```{code-cell}
:id: Cn6_RG4JcKZm
:outputId: 336ae5d2-e35b-4fe5-cbfb-14a7aef28c07
:tags: [raises-exception]
jax.grad(numpy_random_like)(x)
```
+++ {"id": "plvfn9lWcKu4"}
However, if the callback is not dependent on a differentiated variable, it will execute:
```{code-cell}
:id: wxgfDmDfb5bx
:outputId: d8c0285c-cd04-4b4d-d15a-1b07f778882d
@jax.jit
def f(x):
io_callback(lambda: print('hello'), None)
return x
jax.grad(f)(1.0);
```
+++ {"id": "STLI40EZcVIY"}
Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.
+++ {"id": "pkkM1ZmqclV-"}
### Exploring `debug.callback`
Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program.
```{code-cell}
:id: 74TdWyu9eqBa
:outputId: d8551dab-2e61-492e-9ac3-dc3db51b2c18
from jax import debug
def log_value(x):
# This could be an actual logging call; we'll use
# print() for demonstration
print("log:", x)
@jax.jit
def f(x):
debug.callback(log_value, x)
return x
f(1.0);
```
+++ {"id": "P848STlsfzmW"}
The debug callback is compatible with `vmap`:
```{code-cell}
:id: 2sSNsPB-fGVI
:outputId: fff58575-d94c-48fb-b88a-c1c395595fd0
x = jnp.arange(5.0)
jax.vmap(f)(x);
```
+++ {"id": "VDMacqpXf3La"}
And is also compatible with `grad` and other autodiff transformations
```{code-cell}
:id: wkFRle-tfTDe
:outputId: 4e8a81d0-5012-4c51-d843-3fbdc498df31
jax.grad(f)(1.0);
```
+++ {"id": "w8t-SDZ3gRzE"}
This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`.
+++ {"id": "dF7hoWGQUneJ"}
## Example: `pure_callback` with `custom_jvp`
One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).
Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.
Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.
We can start by defining a straightforward `pure_callback`:
```{code-cell}
:id: Ge4fNPZdVSJY
import jax
import jax.numpy as jnp
import scipy.special
def jv(v, z):
v, z = jnp.asarray(v), jnp.asarray(z)
# Require the order v to be integer type: this simplifies
# the JVP rule below.
assert jnp.issubdtype(v.dtype, jnp.integer)
# Promote the input to inexact (float/complex).
# Note that jnp.result_type() accounts for the enable_x64 flag.
z = z.astype(jnp.result_type(float, z.dtype))
# Wrap scipy function to return the expected dtype.
_scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
# Define the expected shape & dtype of output.
result_shape_dtype = jax.ShapeDtypeStruct(
shape=jnp.broadcast_shapes(v.shape, z.shape),
dtype=z.dtype)
# We use vectorized=True because scipy.special.jv handles broadcasted inputs.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
```
+++ {"id": "vyjQj-0QVuoN"}
This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:
```{code-cell}
:id: f4e46670f4e4
j1 = partial(jv, 1)
z = jnp.arange(5.0)
```
```{code-cell}
:id: 6svImqFHWBwj
:outputId: bc8c778a-6c10-443b-9be2-c0f28e2ac1a9
print(j1(z))
```
+++ {"id": "d48eb4f2d48e"}
Here is the same result with `jit`:
```{code-cell}
:id: txvRqR9DWGdC
:outputId: d25f3476-23b1-48e4-dda1-3c06d32c3b87
print(jax.jit(j1)(z))
```
+++ {"id": "d861a472d861"}
And here is the same result again with `vmap`:
```{code-cell}
:id: BS-Ve5u_WU0C
:outputId: 08cecd1f-6953-4853-e9db-25a03eb5b000
print(jax.vmap(j1)(z))
```
+++ {"id": "SCH2ii_dWXP6"}
However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:
```{code-cell}
:id: q3qh_4DrWxdQ
:outputId: c46b0bfa-96f3-4629-b9af-a4d4f3ccb870
:tags: [raises-exception]
jax.grad(j1)(z)
```
+++ {"id": "PtYeJ_xUW09v"}
Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:
$$
d J_\nu(z) = \left\{
\begin{eqnarray}
-J_1(z),\ &\nu=0\\
[J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0
\end{eqnarray}\right.
$$
The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.
We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:
```{code-cell}
:id: BOVQnt05XvLs
jv = jax.custom_jvp(jv)
@jv.defjvp
def _jv_jvp(primals, tangents):
v, z = primals
_, z_dot = tangents # Note: v_dot is always 0 because v is integer.
jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
return jv(v, z), z_dot * djv_dz
```
+++ {"id": "W1SxcvQSX44c"}
Now computing the gradient of our function will work correctly:
```{code-cell}
:id: sCGceBs-X8nL
:outputId: 71c5589f-f996-44a0-f09a-ca8bb40c167a
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
```
+++ {"id": "gWQ4phN5YB26"}
Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:
```{code-cell}
:id: QTe5mRAvYQBh
:outputId: d58ecff3-9419-422a-fd0e-14a7d9cf2cc3
jax.hessian(j1)(2.0)
```
+++ {"id": "QEXGxU4uYZii"}
Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.
When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.
However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities.

View File

@@ -16,3 +16,13 @@ Tutorials
working-with-pytrees
sharded-computation
stateful-computations
.. toctree::
:maxdepth: 1
:caption: Advanced tutorials
advanced-autodiff
external-callbacks
gradient-checkpointing
jax-primitives
jaxpr

View File

@@ -33,7 +33,6 @@ or deployed codebases.
:maxdepth: 1
:caption: Custom operations
notebooks/external_callbacks
pallas/index
ffi