Add/update JAX Advanced Tutorials docs, ToC structure
This commit is contained in:
parent
6b93b35842
commit
0cf040c9a1
@@ -38,10 +38,7 @@ JAX 201
   :maxdepth: 1

   parallelism
   advanced-autodiff
   gradient-checkpointing
   advanced-debugging
   external-callbacks
   profiling-and-performance

JAX 301
@@ -50,6 +47,4 @@ JAX 301
.. toctree::
   :maxdepth: 1

   jax-primitives
   jaxpr
   advanced-compilation
@@ -360,4 +360,6 @@ rediraffe_redirects = {
    'jax-101/07-state.md': 'stateful-computations.md',
    'jax-101/08-pjit.rst': 'sharded-computation.md',
    'jax-101/index.rst': 'tutorials.rst',
    'notebooks/external_callbacks.md': 'external-callbacks.md',
    'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
}
@@ -10,8 +10,6 @@ that use or interface with JAX.
   :caption: Extensible JAX internals
   :maxdepth: 1

   notebooks/How_JAX_primitives_work
   jaxpr
   notebooks/Writing_custom_interpreters_in_Jax
   Custom_Operation_for_GPUs
   jax.extend
@@ -364,7 +364,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
+    "We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
   ]
  },
  {
@@ -311,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
```

-We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
+We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:

```{code-cell} ipython3
jax.make_jaxpr(jax.vmap(rms_norm))(x)
@@ -30,7 +30,7 @@ Glossary of terms
 jaxpr
   Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
   is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
-   See :ref:`understanding-jaxprs` for more discussion and examples.
+   See :ref:`jax-internals-jaxpr` for more discussion and examples.

 JIT
   Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of
@@ -121,8 +121,6 @@ maintains an up-to-date list.

   installation
   quickstart
   notebooks/Common_Gotchas_in_JAX
   faq

.. toctree::
   :hidden:
@@ -130,11 +128,14 @@ maintains an up-to-date list.

   tutorials

   notebooks/Common_Gotchas_in_JAX

   faq

.. toctree::
   :hidden:
   :maxdepth: 2
-   :caption: Resources
+   :caption: More guides/resources

   user_guides
   advanced_guide

docs/jaxpr.rst (472 lines deleted)
@@ -1,472 +0,0 @@
.. _understanding-jaxprs:

Understanding Jaxprs
====================

Updated: May 3, 2020 (for commit f1a46fe).

Conceptually, one can think of JAX transformations as first trace-specializing
the Python function to be transformed into a small and well-behaved
intermediate form that is then interpreted with transformation-specific
interpretation rules. One of the reasons JAX can pack so much power into such a
small software package is that it starts with a familiar and flexible
programming interface (Python with NumPy) and it uses the actual Python
interpreter to do most of the heavy lifting to distill the essence of the
computation into a simple statically-typed expression language with limited
higher-order features. That language is the jaxpr language.

Not all Python programs can be processed this way, but it turns out that many
scientific computing and machine learning programs can.

Before we proceed, it is important to point out that not all JAX
transformations literally materialize a jaxpr as described above; some, e.g.,
differentiation or batching, will apply transformations incrementally during
tracing. Nevertheless, if one wants to understand how JAX works internally, or
to make use of the result of JAX tracing, it is useful to understand jaxprs.

A jaxpr instance represents a function with one or more typed parameters (input
variables) and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes. The
inputs and outputs have types, which in JAX are represented as abstract values.
There are two related representations in the code for jaxprs,
:py:class:`jax.core.Jaxpr` and :py:class:`jax.core.ClosedJaxpr`. A
:py:class:`jax.core.ClosedJaxpr` represents a partially-applied
:py:class:`jax.core.Jaxpr`, and is what you obtain when you use
:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields:

* ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual
  computation content of the function (described below).
* ``consts`` is a list of constants.

The most interesting part of the ClosedJaxpr is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr` as printed using the following
grammar::

  Jaxpr ::= { lambda Var* ; Var+. let
                Eqn*
              in  [Expr+] }

where:

* The parameters of the jaxpr are shown as two lists of variables separated by
  ``;``. The first set of variables are the ones that have been introduced
  to stand for constants that have been hoisted out. These are called the
  ``constvars``, and in a :py:class:`jax.core.ClosedJaxpr` the ``consts``
  field holds corresponding values. The second list of variables, called
  ``invars``, corresponds to the inputs of the traced Python function.
* ``Eqn*`` is a list of equations, defining intermediate variables referring to
  intermediate expressions. Each equation defines one or more variables as the
  result of applying a primitive on some atomic expressions. Each equation uses only
  input variables and intermediate variables defined by previous equations.
* ``Expr+`` is a list of output atomic expressions (literals or variables)
  for the jaxpr.

Equations are printed as follows::

  Eqn ::= Var+ = Primitive [ Param* ] Expr+

where:

* ``Var+`` are one or more intermediate variables to be defined as the output
  of a primitive invocation (some primitives can return multiple values).
* ``Expr+`` are one or more atomic expressions, each either a variable or a
  literal constant. A special variable ``unitvar`` or literal ``unit``,
  printed as ``*``, represents a value that is not needed
  in the rest of the computation and has been elided. That is, units are just
  placeholders.
* ``Param*`` are zero or more named parameters to the primitive, printed in
  square brackets. Each parameter is shown as ``Name = Value``.


Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments)::

  Primitive := add | sub | sin | mul | ...


The jaxpr primitives are documented in the :py:mod:`jax.lax` module.

For example, here is the jaxpr produced for the function ``func1`` below

>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
...   temp = first + jnp.sin(second) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

Here there are no constvars; ``a`` and ``b`` are the input variables,
corresponding respectively to the ``first`` and ``second`` function
parameters. The scalar literal ``3.0`` is kept inline.
The ``reduce_sum`` primitive has the named parameter ``axes``, in addition to the
operand ``e``.

Note that even though execution of a program that calls into JAX builds a jaxpr,
Python-level control-flow and Python-level functions execute normally.
This means that just because a Python program contains functions and control-flow,
the resulting jaxpr does not have to contain control-flow or higher-order features.

For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
jaxpr as before

>>> def func2(inner, first, second):
...   temp = first + inner(second) * 3.
...   return jnp.sum(temp)
...
>>> def inner(second):
...   if second.shape[0] > 4:
...     return jnp.sin(second)
...   else:
...     assert False
...
>>> def func3(first, second):
...   return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }


Handling PyTrees
----------------

In jaxpr there are no tuple types; instead primitives take multiple inputs
and produce multiple outputs. When processing a function that has structured
inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists
of inputs and outputs. For more details, please see the documentation for
PyTrees (:ref:`pytrees`).

For example, the following code produces an identical jaxpr to what we saw
before (with two input vars, one for each element of the input tuple)


>>> def func4(arg):  # Arg is a pair
...   temp = arg[0] + jnp.sin(arg[1]) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }



Constant vars
-------------

Some values in jaxprs are constants, in that their value does not depend on the
jaxpr's arguments. When these values are scalars they are represented directly
in the jaxpr equations; non-scalar array constants are instead hoisted out to
the top-level jaxpr, where they correspond to constant variables ("constvars").
These constvars differ from the other jaxpr parameters ("invars") only as a
bookkeeping convention.
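
As an illustration (a minimal sketch added here; the exact variable names and
formatting of the printed jaxpr depend on the JAX version), tracing a function
that closes over a non-scalar array shows that array hoisted out as a
constvar::

  ones = jnp.ones(8)   # a non-scalar constant captured from the enclosing scope
  def func_const(x):
    return x + ones * 3.

  print(make_jaxpr(func_const)(jnp.zeros(8)))
  # The printed jaxpr begins with something like
  #   { lambda a:f32[8]; b:f32[8]. let ...
  # where ``a`` is the constvar holding the hoisted array and ``b`` is the input.
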
Higher-order primitives
-----------------------

jaxpr includes several higher-order primitives. They are more complicated because
they include sub-jaxprs.

Conditionals
^^^^^^^^^^^^

JAX traces through normal Python conditionals. To capture a
conditional expression for dynamic execution, one must use the
:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors,
which have the signatures::

  lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

  lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

Both of these will bind a primitive called ``cond`` internally. The
``cond`` primitive in jaxprs reflects the more general signature of
:py:func:`lax.switch`: it takes an integer denoting the index of the branch
to execute (clamped into valid indexing range).

For example:

>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
...   return lax.switch(index, [lambda x: x + 1.,
...                             lambda x: x - 2.,
...                             lambda x: x + 3.],
...                     arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    d:i32[] = clamp 0 c 2
    e:f32[] = cond[
      branches=(
        { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
        { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
        { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
      )
    ] d b
  in (e,) }

The `branches` parameter to the cond primitive corresponds to the branch
functionals. In this example, those functionals each take one input variable,
corresponding to ``x``.

The above instance of the cond primitive takes two operands. The first
one (``d``) is the branch index, then ``b`` is the operand (``arg``) to
be passed to whichever jaxpr in ``branches`` is selected by the branch
index.

Another example, using :py:func:`lax.cond`:

>>> from jax import lax
>>>
>>> def func7(arg):
...   return lax.cond(arg >= 0.,
...                   lambda xtrue: xtrue + 3.,
...                   lambda xfalse: xfalse - 3.,
...                   arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
    b:bool[] = ge a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
        { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
      )
    ] c a
  in (d,) }

In this case, the boolean predicate is converted to an integer index
(0 or 1), and ``branches`` are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
one input variable, corresponding to ``xfalse`` and ``xtrue``
respectively.

The following example shows a more complicated situation when the input
to the branch functionals is a tuple, and the `false` branch functional
contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`

>>> def func8(arg1, arg2):  # arg2 is a pair
...   return lax.cond(arg1 >= 0.,
...                   lambda xtrue: xtrue[0],
...                   lambda xfalse: jnp.array([1]) + xfalse[1],
...                   arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
    e:bool[] = ge b 0.0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    g:f32[1] = cond[
      branches=(
        { lambda ; h:i32[1] i:f32[1] j:f32[]. let
            k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
            l:f32[1] = add k j
          in (l,) }
        { lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) }
      )
    ] f a c d
  in (g,) }



While
^^^^^

Just like for conditionals, Python loops are inlined during tracing.
If you want to capture a loop for dynamic execution, you must use one of several
special operations, :py:func:`jax.lax.while_loop` (a primitive)
and :py:func:`jax.lax.fori_loop`
(a helper that generates a while_loop primitive)::

  lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
  lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C


In the above signature, “C” stands for the type of the loop “carry” value.
For example, here is a ``fori_loop``

>>> import numpy as np
>>>
>>> def func10(arg, n):
...   ones = jnp.ones(arg.shape)  # A constant
...   return lax.fori_loop(0, n,
...                        lambda i, carry: carry + ones * 3. + arg,
...                        arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
    c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
    d:f32[16] = add a c
    _:i32[] _:i32[] e:f32[16] = while[
      body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
          k:i32[] = add h 1
          l:f32[16] = mul f 3.0
          m:f32[16] = add j l
          n:f32[16] = add m g
        in (k, i, n) }
      body_nconsts=2
      cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
          r:bool[] = lt o p
        in (r,) }
      cond_nconsts=0
    ] c a 0 b d
  in (e,) }

The while primitive takes 5 arguments: ``c a 0 b d``, as follows:

* 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0)
* 2 constants for ``body_jaxpr`` (``c``, and ``a``)
* 3 parameters for the initial value of carry
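
For comparison, here is a direct use of :py:func:`jax.lax.while_loop` with an
explicit ``cond_fun`` and ``body_fun`` (a minimal sketch added for
illustration, not part of the original text; tracing it likewise produces a
jaxpr with a single ``while`` primitive)::

  def func_while(arg):
    def cond_fun(carry):
      i, _ = carry
      return i < 5
    def body_fun(carry):
      i, acc = carry
      return (i + 1, acc + arg)
    return lax.while_loop(cond_fun, body_fun, (0, arg))
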
Scan
^^^^

JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there are a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the :py:func:`jax.lax.scan` function::

  lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

This is written in terms of a `Haskell Type Signature`_:
``C`` is the type of the scan carry, ``A`` is the element type of the
input array(s), and ``B`` is the element type of the output array(s).

For example, consider the function ``func11`` below

>>> def func11(arr, extra):
...   ones = jnp.ones(arr.shape)  # A constant
...   def body(carry, aelems):
...     # carry: running dot-product of the two arrays
...     # aelems: a pair with corresponding elements from the two arrays
...     ae1, ae2 = aelems
...     return (carry + ae1 * ae2 + extra, carry)
...   return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
    c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
    d:f32[] e:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
          j:f32[] = mul h i
          k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          l:f32[] = add k j
          m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          n:f32[] = add l m
        in (n, g) }
      length=16
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0 a c
  in (d, e) }

The ``linear`` parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.

The scan primitive takes 4 arguments: ``b 0.0 a c``, of which:

* one is the free variable for the body
* one is the initial value of the carry
* the next 2 are the arrays over which the scan operates.

XLA_call
^^^^^^^^

The call primitive arises from JIT compilation, and it encapsulates
a sub-jaxpr along with parameters that specify the backend and the device on
which the computation should run. For example

>>> from jax import jit
>>>
>>> def func12(arg):
...   @jit
...   def inner(x):
...     return x + arg * jnp.ones(1)  # Include a constant in the inner function
...   return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))  # doctest:+ELLIPSIS
{ lambda ; a:f32[]. let
    b:f32[] = sub a 2.0
    c:f32[1] = pjit[
      name=inner
      jaxpr={ lambda ; d:f32[] e:f32[]. let
          f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
          g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          h:f32[1] = mul g f
          i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
          j:f32[1] = add i h
        in (j,) }
    ] a b
    k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    l:f32[1] = add k c
  in (l,) }


XLA_pmap
^^^^^^^^

If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
captured using the ``xla_pmap`` primitive. Consider this example

>>> from jax import pmap
>>>
>>> def func13(arr, extra):
...   def inner(x):
...     # use a free variable "extra" and a constant jnp.ones(1)
...     return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
...   return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a:f32[1,3] b:f32[]. let
    c:f32[1,3] = xla_pmap[
      axis_name=rows
      axis_size=1
      backend=None
      call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
          f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          g:f32[3] = add e f
          h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
          i:f32[3] = add g h
          j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
          k:f32[3] = div i j
        in (k,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=1
      in_axes=(None, 0)
      is_explicit_global_axis_size=False
      name=inner
      out_axes=(0,)
    ] b a
  in (c,) }

The ``xla_pmap`` primitive specifies the name of the axis (parameter
``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``
parameter. The value of this parameter is a Jaxpr with 2 input variables.

The parameter ``in_axes`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast and the value of ``arr`` is mapped.

.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature

@@ -51,7 +51,7 @@ def log2(x):
print(jax.make_jaxpr(log2)(3.0))
```

-The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output.
+The {ref}`jax-internals-jaxpr` section of the documentation provides more information on the meaning of the above output.

Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`.
This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.

(File diff suppressed because it is too large)

@@ -1,771 +0,0 @@
---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3
  name: python3
---

+++ {"id": "vfxqky4PCUnh"}

# How JAX primitives work

<!--* freshness: { reviewed: '2024-04-08' } *-->

[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)

*necula@google.com*, October 2019.

JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,
`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,
which means that as the Python function executes
the only operations it applies to the data are either inspections of data
attributes such as shape or type, or special operations called JAX primitives.
In particular, a JAX-traceable function is sometimes invoked by JAX with
abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,
which captures the type and the shape of values, but not the concrete data values.
JAX primitives know how to operate on both concrete data
values and on the JAX abstract values.


The JAX-transformed functions must themselves be JAX-traceable functions,
to ensure that these transformations
can be composed, e.g., `jit(jacfwd(grad(f)))`.

There are pre-defined JAX primitives corresponding to most XLA operations,
e.g., add, matmul, sin, cos, indexing.
JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs
using JAX’s implementation of numpy are JAX-traceable and therefore transformable.
Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.

The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,
one can define a new primitive that encapsulates the behavior of the function.

**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**

Consider that we want to add support to JAX for a multiply-add function with three arguments, defined mathematically
as "multiply_add(x, y, z) = x * y + z".
This function operates on 3 identically-shaped tensors of floating point
values and performs the operations pointwise.

+++ {"id": "HIJYIHNTD1yI"}

## Using existing primitives

The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other
functions that are themselves written using JAX primitives, e.g., those
defined in the `jax.lax` module:

```{code-cell} ipython3
:id: tbOF0LB0EMne
:outputId: 3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528

from jax import lax
from jax._src import api

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the jax.lax primitives."""
  return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
```

+++ {"id": "Cgv60Wm3E_D5"}

In order to understand how JAX is internally using the primitives,
we add some helpers for tracing function calls.

```{code-cell} ipython3
:cellView: form
:id: mQRQGEGiE53K

#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print(" " * _indentation + msg)

def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
      """Print certain values more succinctly"""
      vtype = str(type(v))
      if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
        return "<JaxComputationBuilder>"
      elif "jaxlib.xla_extension.XlaOp" in vtype:
        return "<XlaOp at 0x{:x}>".format(id(v))
      elif ("partial_eval.JaxprTracer" in vtype or
            "batching.BatchTracer" in vtype or
            "ad.JVPTracer" in vtype):
        return "Traced<{}>".format(v.aval)
      elif isinstance(v, tuple):
        return "({})".format(pp_values(v))
      else:
        return str(v)
    def pp_values(args):
      return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False
```

+++ {"id": "Qf4eLrLCFYDl"}

Instead of using `jax.lax` primitives directly, we can use other functions
that are already written in terms of those primitives, such as those in `jax.numpy`:

```{code-cell} ipython3
:id: QhKorz6cFRJb
:outputId: aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a

import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
```

+++ {"id": "Sg-D8EdeFn4a"}

Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and
`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further
below in this colab).
It is important to remember that a JAX-traceable function must be able to
operate not only on concrete arguments but also on special abstract arguments
that JAX may use to abstract the function execution.

The JAX traceability property is satisfied as long as the function is written
in terms of JAX primitives.

+++ {"id": "WxrQO7-XGLcg"}

## Defining new JAX primitives

The right way to add support for multiply-add is in terms of existing
JAX primitives, as shown above. However, in order to demonstrate how JAX
primitives work let us pretend that we want to add a new primitive to
JAX for the multiply-add functionality.

```{code-cell} ipython3
:id: cPqAH1XOGTN4

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

  Note that the traced arguments must be passed as positional arguments
  to `bind`.
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)
```

+++ {"id": "LMzs5PAKGr-4"}

If we try to call the newly defined functions we get an error, because
we have not yet told JAX anything about the semantics of the new primitive.

```{code-cell} ipython3
:id: _X3PAYxhGpWd
:outputId: 90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8

with expectNotImplementedError():
  square_add_prim(2., 10.)
```

+++ {"id": "elha0FdgHSEF"}

### Primal evaluation rules

```{code-cell} ipython3
:id: FT34FFAGHARU
:outputId: 4c54f1c2-8a50-4788-90e1-06aee412c43b

@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.
  Args:
    x, y, z: the concrete arguments of the primitive. Will only be called with
      concrete values.
  Returns:
    the concrete result of the primitive.
  """
  # Note that we can use the original numpy, which is not JAX traceable
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
```

```{code-cell} ipython3
:id: G5bstKaeNAVV
:outputId: deb94d5b-dfea-4e6f-9ec2-70b416c996c5

assert square_add_prim(2., 10.) == 14.
```

+++ {"id": "upBf-uAuHhPJ"}

### JIT

If we now try to use `jit` we get a `NotImplementedError`:

```{code-cell} ipython3
:id: QG-LULjiHk4b
:outputId: d4ef4406-8dae-4c96-97ca-b662340474ee

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
```

+++ {"id": "rHS1bAGHH44E"}

#### Abstract evaluation rules
In order to JIT the function, and for other transformations as well,
JAX first evaluates it abstractly using only the
shape and type of the arguments. This abstract evaluation serves multiple
purposes:

* Gets the sequence of JAX primitives that are used in the computation. This
  sequence will be compiled.
* Computes the shape and type of all vectors and operations used in the computation.


For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
In the latter case, JAX uses the actual concrete value wrapped as an abstract value.

```{code-cell} ipython3
:id: ctQmEeckIbdo
:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2

from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments.
  Args:
    xs, ys, zs: abstractions of the arguments.
  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
```

+++ {"id": "RPN88X6YI43A"}

If we re-attempt to JIT, we see how the abstract evaluation proceeds, but
we get another error, about missing the actual XLA compilation rule:

```{code-cell} ipython3
:id: eOcNR92SI2h-
:outputId: 356ef229-3703-4696-cc3d-7c05de405fb0

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
```

+++ {"id": "9IOV1R-fJMHp"}

#### XLA Compilation rules

JAX compilation works by compiling each primitive into a graph of XLA operations.

This is the biggest hurdle to adding new functionality to JAX, because the
set of XLA operations is limited, and JAX already has pre-defined primitives
for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.

```{code-cell} ipython3
:id: FYQWSSjKJaWP

from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
```

+++ {"id": "K98LX-VaJkFu"}

Now the JIT succeeds. Notice below that JAX first evaluates the function
abstractly, which triggers the `multiply_add_abstract_eval` function, and
then compiles the set of primitives it has encountered, including `multiply_add`.
At this point JAX invokes `multiply_add_lowering`.

```{code-cell} ipython3
:id: rj3TLsolJgEc
:outputId: e384bee4-1e9c-4344-f49c-d3b5ec08eb32

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
```

+++ {"id": "Omrez-2_KFfo"}

Below is another use of `jit` where we compile only
with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads
to the third argument to `multiply_add_abstract_eval` being a
`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with
both `ShapedArray` and `ConcreteArray`.

```{code-cell} ipython3
:id: mPfTwIBoKOEK
:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b

assert api.jit(lambda x, y: square_add_prim(x, y),
               static_argnums=1)(2., 10.) == 14.
```

+++ {"id": "_Ya3B5l4J1VA"}

### Forward differentiation

JAX implements forward differentiation in the form of
a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).

If we attempt now to compute the `jvp` function we get an
error because we have not yet told JAX how to differentiate
the `multiply_add` primitive.

```{code-cell} ipython3
:id: OxDx6NQnKwMI
:outputId: ce659ef3-c03c-4856-f252-49ec4b6eb964

# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
```

```{code-cell} ipython3
:id: zxG24C1JMIMM

from jax.interpreters import ad


@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be the
      special value ad.Zero to specify a zero tangent.
  Returns:
    a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output.
  # Normally, we can use the ma primitive itself to compute the primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # We do need to deal specially with Zero. Here we just turn it into a
  # proper tensor of 0s (of the same shape as 'x').
  # An alternative would be to check for Zero and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
```

```{code-cell} ipython3
:id: ma3KBkiAMfW1
:outputId: f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd

# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
```

+++ {"id": "69QsEcu-lP4u"}

TO EXPLAIN:

* Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
* Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet
  we do not call the multiply_add_abstract_eval.
* I think it would be useful to show the jaxpr here (see the sketch below).
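
One way to show that jaxpr (a minimal sketch added for illustration, not part of the original notebook; the exact printed output depends on the JAX version):

```{code-cell} ipython3
from jax import make_jaxpr

# Trace the JVP of square_add_prim to see the multiply_add primitives it contains.
print(make_jaxpr(lambda x, y: api.jvp(square_add_prim, (x, y), (1., 1.)))(2., 10.))
```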
+++ {"id": "Sb6e3ZAHOPHv"}

#### JIT of forward differentiation

We can apply JIT to the forward differentiation function:

```{code-cell} ipython3
:id: hg-hzVu-N-hv
:outputId: 38d32067-e152-4046-ad80-7f95a31ba628

assert api.jit(lambda arg_values, arg_tangents:
               api.jvp(square_add_prim, arg_values, arg_tangents))(
                 (2., 10.), (1., 1.)) == (14., 5.)
```

+++ {"id": "jlZt1_v2mU88"}

Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn
evaluates abstractly both the primal and the tangent evaluation (a total of
3 invocations of the `ma` primitive). Then we compile the 3 occurrences
of the primitive.

+++ {"id": "555yt6ZIOePB"}

### Reverse differentiation

If we attempt now to use reverse differentiation we
see that JAX starts by using the `multiply_add_value_and_jvp` to
compute the forward differentiation for abstract values, but then runs
into a `NotImplementedError`.

When computing the reverse differentiation JAX first does abstract evaluation
of the forward differentiation code `multiply_add_value_and_jvp` to obtain a
trace of primitives that compute the output tangent.
Observe that JAX performs this abstract evaluation with concrete values
for the differentiation point, and abstract values for the tangents.
Observe also that JAX uses the special abstract tangent value `Zero` for
the tangent corresponding to the 3rd argument of `ma`. This reflects the
fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,
which flows to the 3rd argument to `multiply_add_prim`.

Observe also that during the abstract evaluation of the tangent we pass the
value 0.0 as the tangent for the 3rd argument. This is due to the use
of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.

```{code-cell} ipython3
:id: 8eAVnexaOjBn
:outputId: e4ee89cf-ab4a-4505-9817-fa978a2865ab

# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
```

+++ {"id": "fSHLUMDN26AY"}

The above error is because there is a missing piece for JAX to be able
to use the forward differentiation code to compute reverse differentiation.

+++ {"id": "3ibDbGF-PjK9"}

#### Transposition


As explained above, when computing reverse differentiation JAX obtains
a trace of primitives that compute the tangent using forward differentiation.
Then, **JAX interprets this trace abstractly backwards** and for each
primitive it applies a **transposition** rule.

To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:
```
   a = xt * 4.
   b = 2. * yt
   c = a + b
   ft = c + yt
```

By construction, the tangent calculation is always linear in the input tangents.
The only non-linear operator that may arise in the tangent calculation is multiplication,
but then one of the operands is constant.

JAX will produce the reverse differentiation computation by processing the
JVP computation backwards. For each operation in the tangent computation,
it accumulates the cotangents
of the variables used by the operation, using the cotangent of the result
of the operation:
```
  # Initialize cotangents of inputs and intermediate vars
  xct = yct = act = bct = cct = 0.
  # Initialize cotangent of the output
  fct = 1.
  # Process "ft = c + yt"
  cct += fct
  yct += fct
  # Process "c = a + b"
  act += cct
  bct += cct
  # Process "b = 2. * yt"
  yct += 2. * bct
  # Process "a = xt * 4."
  xct += act * 4.
```

One can verify that this computation produces `xct = 4.` and `yct = 3.`, which
are the partial derivatives of the function `f`.

JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:
```
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
```

Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other
arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned
for the constant arguments.

In particular,
```
 add_transpose(out_ct, _, _) = (out_ct, out_ct)
 mult_transpose(out_ct, x, _) = (None, x * out_ct)
 mult_transpose(out_ct, _, y) = (out_ct * y, None)
```

```{code-cell} ipython3
:id: JaHxFdkRO42r

@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following
  value_and_jvp, and is only needed for primitives that are used in the JVP
  calculation for some other primitive. We need transposition for multiply_add_prim,
  because we have used multiply_add_prim in the computation of the output_tangent in
  multiply_add_value_and_jvp.

  In our case, multiply_add is not a linear primitive. However, it is used linearly
  w.r.t. tangents in multiply_add_value_and_jvp:
    output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))

  One of the first two multiplicative arguments is always a constant.

  Args:
      ct: the cotangent of the output of the primitive.
      x, y, z: values of the arguments. The arguments that are used linearly
        get an ad.UndefinedPrimal value. The other arguments get a constant
        value.
  Returns:
      a tuple with the cotangent of the inputs, with the value None
      corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x"
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y"
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res


ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
```

+++ {"id": "PpChox-Jp7wb"}

Now we can complete the run of the `grad`:

```{code-cell} ipython3
:id: PogPKS4MPevd
:outputId: d33328d4-3e87-45b5-9b31-21ad624b67af

assert api.grad(square_add_prim)(2., 10.) == 4.
```

+++ {"id": "8M1xLCXW4fK7"}

Notice the two calls to `multiply_add_transpose`. They correspond to the two
uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the
last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0.

+++ {"id": "EIJs6FYmPg6c"}

#### JIT of reverse differentiation

Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only
abstract values, while in the absence of JIT we used `ConcreteArray`.

```{code-cell} ipython3
:id: FZ-JGbWZPq2-
:outputId: e42b5222-9c3e-4853-e13a-874f6605d178

assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
```

+++ {"id": "-3lqPkdQPvl5"}

### Batching

The batching transformation takes a point-wise computation and turns it
into a computation on vectors. If we try it right now, we get a `NotImplementedError`:

```{code-cell} ipython3
:id: hFvBR3I9Pzh3
:outputId: 434608bc-281f-4d3b-83bd-eaaf3b51b1cd

# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
```

+++ {"id": "gILasMiP6elR"}

We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, `multiply_add_prim` already operates pointwise for any dimension of input vectors, so the batched version can use the same `multiply_add_prim` implementation.

```{code-cell} ipython3
:id: KQfeqRIrP7zg

from jax.interpreters import batching


@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works as
  long as both the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the same axis as the inputs.

  Args:
    vector_arg_values: a tuple of two arguments, each being a tensor of matching
      shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]


batching.primitive_batchers[multiply_add_p] = multiply_add_batch
```

```{code-cell} ipython3
:id: VwxNk869P_YG
:outputId: 9d22c921-5803-4d33-9e88-b6e439ba9738

assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
```

+++ {"id": "NmqLlV1TQDCC"}

#### JIT of batching

```{code-cell} ipython3
:id: xqEdXVUgQCTt
:outputId: 9c22fd9c-919c-491d-bbeb-32c241b808fa

assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
  (np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
```

(File diff suppressed because it is too large)

@@ -1,515 +0,0 @@
---
|
||||
jupytext:
|
||||
formats: ipynb,md:myst
|
||||
text_representation:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
format_version: 0.13
|
||||
jupytext_version: 1.16.4
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
name: python3
|
||||
---
|
||||
|
||||
+++ {"id": "7XNMxdTwURqI"}
|
||||
|
||||
# External callbacks
|
||||
|
||||
<!--* freshness: { reviewed: '2024-04-08' } *-->
|
||||
|
||||
+++ {"id": "h6lXo6bSUYGq"}
|
||||
|
||||
This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation.
|
||||
|
||||
+++ {"id": "Xi_nhfpnlmbm"}
|
||||
|
||||
## Why callbacks?
|
||||
|
||||
A callback routine is a way to perform **host-side** execution of code at runtime.
|
||||
As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.
|
||||
Using a simple Python `print` statement, it looks like this:
|
||||
|
||||
```{code-cell}
|
||||
:id: lz8rEL1Amb4r
|
||||
:outputId: bbd37102-19f2-46d2-b794-3d4952c6fe97
|
||||
|
||||
import jax
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x + 1
|
||||
print("intermediate value: {}".format(y))
|
||||
return y * 2
|
||||
|
||||
result = f(2)
|
||||
```
|
||||
|
||||
+++ {"id": "yEy41sFAmxOp"}
|
||||
|
||||
What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).
|
||||
|
||||
To print the value at runtime we need a callback, for example `jax.debug.print`:
|
||||
|
||||
```{code-cell}
|
||||
:id: wFfHmoQxnKDF
|
||||
:outputId: 6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x + 1
|
||||
jax.debug.print("intermediate value: {}", y)
|
||||
return y * 2
|
||||
|
||||
result = f(2)
|
||||
```

+++ {"id": "CvWv3pudn9X5"}

This works by passing the runtime value represented by `y` back to the host process, where the host can print the value.

+++ {"id": "X0vR078znuT-"}

## Flavors of Callback

In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:

- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effects.
- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.
- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.

(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`; a minimal sketch of such a wrapper appears after the table below.)

From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.

|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|-------------------------------------|----|----|----|----|----|----|
|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |
|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |
|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |

¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff

² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.

³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases.
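
As noted above, {func}`jax.debug.print` is essentially a thin wrapper around {func}`jax.debug.callback`. The following is only an illustrative sketch of such a wrapper, not JAX's actual implementation; the helper names `my_debug_print` and `print_example` are made up for this example:

```{code-cell}
def my_debug_print(fmt, *args):
  # Hand the runtime values back to the host and format them there.
  jax.debug.callback(lambda *host_args: print(fmt.format(*host_args)), *args)

@jax.jit
def print_example(x):
  my_debug_print("runtime value: {}", x)
  return x * 2

print_example(3.0);
```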

+++ {"id": "hE_M8DaPvoym"}

### Exploring `jax.pure_callback`

`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).

The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times.

```{code-cell}
:id: 4lQDzXy6t_-k
:outputId: 279e4daf-0540-4eab-f535-d3bcbac74c44

import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)

x = jnp.arange(5.0)
f(x)
```

+++ {"id": "q7YCIr8qMrDs"}

Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:

```{code-cell}
:id: bgoZ0fxsuoWV
:outputId: 901443bd-5cb4-4923-ce53-6f832ac22ca9

jax.jit(f)(x)
```

```{code-cell}
:id: ajBRGWGfupu2
:outputId: b28e31ee-7457-4b92-872b-52d819f53ddf

jax.vmap(f)(x)
```

```{code-cell}
:id: xe7AOGexvC13
:outputId: 8fa77977-1f2b-41c5-cc5e-11993ee5aa3e

def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
```

+++ {"id": "tMzAVs2VNj5G"}

However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:

```{code-cell}
:id: 4QAF4VhUu5bb
:outputId: f8a06d02-47e9-4240-8077-d7be81e5a480

%xmode minimal
```

```{code-cell}
:id: qUpKPxlOurfY
:outputId: 11a665e8-40eb-4b0e-dc2e-a544a25fc57e
:tags: [raises-exception]

jax.grad(f)(x)
```

+++ {"id": "y9DAibV4Nwpo"}

For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below.

+++ {"id": "LrvdAloMZbIe"}

By design, functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:

```{code-cell}
:id: mmFc_zawZrBq
:outputId: a4df7568-3f64-4b2f-9a2c-7adb2e0815e0

def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1();
```

```{code-cell}
:id: tTwE4kpmaNei

@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2();
```

+++ {"id": "qfyGYbw4Z5U3"}

In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.
In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.
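
One rough way to confirm this elision (just a sketch, not an officially supported check) is to look for the callback's custom call in the compiled program text. The exact name of that custom call is backend-dependent, so here we simply search for the substring `"callback"`:

```{code-cell}
# f1 and f2 are already jit-wrapped, so we can lower and compile them directly.
print("f1 mentions 'callback':", "callback" in f1.lower().compile().as_text())
print("f2 mentions 'callback':", "callback" in f2.lower().compile().as_text())
```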

+++ {"id": "JHcJybr7OEBM"}

### Exploring `jax.experimental.io_callback`

In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.

As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!).

```{code-cell}
:id: eAg5xIhrOiWV
:outputId: e3cfec21-d843-4852-a49d-69a69fba9fc1

from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)
```

+++ {"id": "mAIF31MlXj33"}

The `io_callback` is compatible with `vmap` by default:

```{code-cell}
:id: NY3o5dG6Vg6u
:outputId: a67a8a98-214e-40ca-ad98-a930cd3db85e

jax.vmap(numpy_random_like)(x)
```

+++ {"id": "XXvSeeOXXquZ"}

Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.

If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:

```{code-cell}
:id: 3aNmRsDrX3-2
:outputId: a8ff4b77-f4cb-442f-8cfb-ea7251c66274
:tags: [raises-exception]

@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x)
```

+++ {"id": "fD2FTHlUYAZH"}

On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:

```{code-cell}
:id: lMVzZlIEWL7F
:outputId: f9741c18-a30d-4d46-b706-8102849286b5

def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
```
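
For completeness, here is a similar sketch using `while_loop`, reusing the `host_side_random_like` helper from above (purely illustrative; the carry structure and loop bound are arbitrary choices for this example):

```{code-cell}
def cond_fun(carry):
  i, _ = carry
  return i < 3

def while_body_fun(carry):
  i, x = carry
  # Each iteration triggers the ordered io_callback on the host.
  return i + 1, numpy_random_like_ordered(x)

jax.lax.while_loop(cond_fun, while_body_fun, (0, jnp.zeros(5)))
```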

+++ {"id": "w_sf8mCbbo8K"}

Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:

```{code-cell}
:id: Cn6_RG4JcKZm
:outputId: 336ae5d2-e35b-4fe5-cbfb-14a7aef28c07
:tags: [raises-exception]

jax.grad(numpy_random_like)(x)
```

+++ {"id": "plvfn9lWcKu4"}

However, if the callback is not dependent on a differentiated variable, it will execute:

```{code-cell}
:id: wxgfDmDfb5bx
:outputId: d8c0285c-cd04-4b4d-d15a-1b07f778882d

@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0);
```

+++ {"id": "STLI40EZcVIY"}

Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.

+++ {"id": "pkkM1ZmqclV-"}

### Exploring `debug.callback`

Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program.

```{code-cell}
:id: 74TdWyu9eqBa
:outputId: d8551dab-2e61-492e-9ac3-dc3db51b2c18

from jax import debug

def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0);
```

+++ {"id": "P848STlsfzmW"}

The debug callback is compatible with `vmap`:

```{code-cell}
:id: 2sSNsPB-fGVI
:outputId: fff58575-d94c-48fb-b88a-c1c395595fd0

x = jnp.arange(5.0)
jax.vmap(f)(x);
```

+++ {"id": "VDMacqpXf3La"}

And is also compatible with `grad` and other autodiff transformations:

```{code-cell}
:id: wkFRle-tfTDe
:outputId: 4e8a81d0-5012-4c51-d843-3fbdc498df31

jax.grad(f)(1.0);
```

+++ {"id": "w8t-SDZ3gRzE"}

This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`.

+++ {"id": "dF7hoWGQUneJ"}

## Example: `pure_callback` with `custom_jvp`

One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).
Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.

Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.
We can start by defining a straightforward `pure_callback`:

```{code-cell}
:id: Ge4fNPZdVSJY

import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # We use vectorized=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
```

+++ {"id": "vyjQj-0QVuoN"}

This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:

```{code-cell}
:id: f4e46670f4e4

j1 = partial(jv, 1)
z = jnp.arange(5.0)
```

```{code-cell}
:id: 6svImqFHWBwj
:outputId: bc8c778a-6c10-443b-9be2-c0f28e2ac1a9

print(j1(z))
```

+++ {"id": "d48eb4f2d48e"}

Here is the same result with `jit`:

```{code-cell}
:id: txvRqR9DWGdC
:outputId: d25f3476-23b1-48e4-dda1-3c06d32c3b87

print(jax.jit(j1)(z))
```

+++ {"id": "d861a472d861"}

And here is the same result again with `vmap`:

```{code-cell}
:id: BS-Ve5u_WU0C
:outputId: 08cecd1f-6953-4853-e9db-25a03eb5b000

print(jax.vmap(j1)(z))
```

+++ {"id": "SCH2ii_dWXP6"}

However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:

```{code-cell}
:id: q3qh_4DrWxdQ
:outputId: c46b0bfa-96f3-4629-b9af-a4d4f3ccb870
:tags: [raises-exception]

jax.grad(j1)(z)
```

+++ {"id": "PtYeJ_xUW09v"}

Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:

$$
\frac{d}{dz} J_\nu(z) = \begin{cases}
  -J_1(z), & \nu = 0 \\
  \left[J_{\nu - 1}(z) - J_{\nu + 1}(z)\right] / 2, & \nu \ne 0
\end{cases}
$$

The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.

We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:

```{code-cell}
:id: BOVQnt05XvLs

jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz
```

+++ {"id": "W1SxcvQSX44c"}

Now computing the gradient of our function will work correctly:

```{code-cell}
:id: sCGceBs-X8nL
:outputId: 71c5589f-f996-44a0-f09a-ca8bb40c167a

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
```

+++ {"id": "gWQ4phN5YB26"}

Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:

```{code-cell}
:id: QTe5mRAvYQBh
:outputId: d58ecff3-9419-422a-fd0e-14a7d9cf2cc3

jax.hessian(j1)(2.0)
```

+++ {"id": "QEXGxU4uYZii"}

Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.
When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.
However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities.
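
If you want to get a rough feel for this overhead on your own machine, a simple timing comparison can help. This is only an illustrative sketch, not a rigorous benchmark: `jnp.sin` is used merely as a stand-in for a comparably cheap natively-compiled operation, and results will vary by backend and input size:

```{code-cell}
import time

z_big = jnp.arange(1024.0)
jv_jit = jax.jit(j1)
sin_jit = jax.jit(jnp.sin)

def avg_seconds(fn, arg, n=100):
  fn(arg).block_until_ready()  # warm up / compile outside the timed loop
  start = time.perf_counter()
  for _ in range(n):
    fn(arg).block_until_ready()
  return (time.perf_counter() - start) / n

print("callback-based jv:", avg_seconds(jv_jit, z_big))
print("native jnp.sin:   ", avg_seconds(sin_jit, z_big))
```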
@ -16,3 +16,13 @@ Tutorials
   working-with-pytrees
   sharded-computation
   stateful-computations

.. toctree::
   :maxdepth: 1
   :caption: Advanced tutorials

   advanced-autodiff
   external-callbacks
   gradient-checkpointing
   jax-primitives
   jaxpr

@ -33,7 +33,6 @@ or deployed codebases.
   :maxdepth: 1
   :caption: Custom operations

   notebooks/external_callbacks
   pallas/index
   ffi