Understanding Jaxprs
====================

Updated: May 3, 2020 (for commit f1a46fe).

Conceptually, one can think of JAX transformations as first tracing the Python
function to be transformed into a small and well-behaved intermediate form,
the jaxpr, that is then transformed accordingly, and ultimately compiled and executed.

One of the reasons JAX can pack so much power into such a small software package
is that it starts with a familiar and flexible programming interface (Python with NumPy)
and it uses the actual Python interpreter to do most of the heavy lifting to distill the
essence of the computation into a simple statically-typed expression language
with limited higher-order features: the jaxpr language.

Not all Python programs can be processed this way, but it turns out that many
scientific computing and machine learning programs do have this property.

Before we proceed, it is important to point out that not all JAX transformations
materialize a jaxpr as described above; some, e.g., differentiation,
will apply transformations incrementally during tracing.
Nevertheless, if one wants to understand how JAX works internally, or to
make use of the result of JAX tracing, it is useful to understand jaxprs.

A jaxpr instance represents a function with one or more typed parameters (input variables)
and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes.
The inputs and outputs have types, which in JAX are represented as abstract
values. There are two related representations in the code for jaxprs. The main
one is :py:class:`jax.core.TypedJaxpr` and is what you obtain when you
use :py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following
fields:

* ``jaxpr`` is the actual computation content of the function (described below).
* ``literals`` is a list of constants. For various reasons, during tracing JAX
  will collect the non-scalar constants that arise and will replace them with
  variables, e.g., constants that appear in the Python program, or the result of
  constant folding such constants. The variables that stand for these constants
  are mentioned separately in the enclosed ``jaxpr``.
  When applying a ``TypedJaxpr`` to some actual
  arguments, one must pass first the ``literals`` followed by the actual arguments.
* ``in_avals`` and ``out_avals`` are the types of the input variables
  (excluding the ones that correspond to the ``literals``) and of the output values.
  In JAX these types are called abstract values, e.g., ``ShapedArray(float32[10,10])``.
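
To make the fields concrete, here is a small sketch that traces a function and
inspects the resulting object. (The wrapper class and exact printed forms vary
between JAX versions, but the ``jaxpr``, ``literals``, ``in_avals``, and
``out_avals`` attributes are available on the object ``make_jaxpr`` returns.)

```python
import jax
import jax.numpy as jnp

def func(first, second):
    return jnp.sum(first + jnp.sin(second) * 3.)

# Trace the function; the returned object wraps the jaxpr together with
# its hoisted constants and the input/output abstract values.
traced = jax.make_jaxpr(func)(jnp.zeros(8), jnp.ones(8))

print(traced.jaxpr)      # the computation content
print(traced.literals)   # the list of hoisted constants (empty here)
print(traced.in_avals)   # abstract values of the two inputs
print(traced.out_avals)  # abstract value of the single output
```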

The most interesting part of the ``TypedJaxpr`` is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr` as printed using the following
grammar::

  jaxpr ::= { lambda Var* ; Var+.
              let Eqn*
              in [Expr+] }

where:

* The parameters of the jaxpr are shown as two lists of variables separated by
  ``;``. The first set of variables are the ones that have been introduced
  to stand for constants that have been hoisted out. These are called the
  `constvars`. The second list of variables are the real input variables.

* ``Eqn*`` is a list of equations, defining intermediate variables referring to
  intermediate expressions. Each equation defines one or more variables as the
  result of applying a primitive on some atomic expressions. Each equation uses only
  input variables and intermediate variables defined by previous equations.
* ``Expr+`` is a list of output atomic expressions for the jaxpr.

Equations are printed as follows::

  Eqn ::= let Var+ = Primitive [ Param* ] Expr+

where:

* ``Var+`` are one or more intermediate variables to be defined as the
  output of a primitive invocation (some primitives can return multiple values).
* ``Expr+`` are one or more atomic expressions, each either a variable or a
  literal constant. A special form of an atomic expression is the `unit`
  expression, printed as ``*`` and standing for a value that is not needed
  in the rest of the computation and has been elided.
* ``Param*`` are zero or more named parameters to the primitive, printed in
  square brackets. Each parameter is shown as ``Name = Value``.
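
The grammar mirrors the data structures in ``jax.core``: a jaxpr holds a list of
equations, and each equation records its primitive, inputs, outputs, and named
parameters. The following sketch walks these equations programmatically
(attribute names as in ``jax.core``, which may shift between versions):

```python
import jax
import jax.numpy as jnp

closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 3.0)(1.0)
jaxpr = closed.jaxpr

# Each equation corresponds to one line of the pretty-printed form:
#   outvars = primitive [ params ] invars
for eqn in jaxpr.eqns:
    print(eqn.primitive.name, eqn.params, eqn.invars, "->", eqn.outvars)

prims = [eqn.primitive.name for eqn in jaxpr.eqns]
print(prims)
```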

Most jaxpr primitives are first-order (they take just one or more Expr as arguments)::

  Primitive := add | sub | sin | mul | ...

The jaxpr primitives are documented in the :py:mod:`jax.lax` module.

For example, here is the jaxpr produced for the function ``func1`` below

>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
...   temp = first + jnp.sin(second) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }

Here there are no constvars, ``a`` and ``b`` are the input variables
and they correspond respectively to
``first`` and ``second`` function parameters. The scalar literal ``3.0`` is kept
inline.
The ``reduce_sum`` primitive has the named parameter ``axes``, in
addition to the operand ``e``.
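
A traced jaxpr is not just for printing: it can be evaluated directly with the
small interpreter in ``jax.core``. The exact location of ``eval_jaxpr`` has
moved between JAX versions, so treat the following as a sketch. Note that the
hoisted constants (``literals``) are passed first, then the actual arguments:

```python
import jax
import jax.numpy as jnp
from jax import make_jaxpr

def func1(first, second):
    temp = first + jnp.sin(second) * 3.
    return jnp.sum(temp)

closed = make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))

# eval_jaxpr interprets the jaxpr one equation at a time.
out, = jax.core.eval_jaxpr(closed.jaxpr, closed.literals,
                           jnp.zeros(8), jnp.ones(8))
print(out)  # same value as calling func1 directly
```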

Note that JAX traces through Python-level control-flow and higher-order functions
when it extracts the jaxpr. This means that even though a Python program contains
functions and control-flow, the resulting jaxpr does not have
to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
jaxpr as before

>>> def func2(inner, first, second):
...   temp = first + inner(second) * 3.
...   return jnp.sum(temp)
...
>>> def inner(second):
...   if second.shape[0] > 4:
...     return jnp.sin(second)
...   else:
...     assert False
...
>>> def func3(first, second):
...   return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }

Handling PyTrees
----------------

In jaxpr there are no tuple types; instead primitives take multiple inputs
and produce multiple outputs. When processing a function that has structured
inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists
of inputs and outputs. For more details, please see the documentation for
PyTrees (:doc:`notebooks/JAX_pytrees`).

For example, the following code produces an identical jaxpr to what we saw
before (with two input vars, one for each element of the input tuple)

>>> def func4(arg):  # Arg is a pair
...   temp = arg[0] + jnp.sin(arg[1]) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }

Constant Vars
-------------

Some values in jaxprs are constants, in that their value does not depend on the
jaxpr's arguments. When these values are scalars they are represented directly
in the jaxpr equations; non-scalar array constants are instead hoisted out to
the top-level jaxpr, where they correspond to constant variables ("constvars").
These constvars differ from the other jaxpr parameters ("invars") only as a
bookkeeping convention.

Higher-order primitives
-----------------------

jaxpr includes several higher-order primitives. They are more complicated because
they include sub-jaxprs.

Conditionals
^^^^^^^^^^^^

JAX traces through normal Python conditionals. To capture a
conditional expression for dynamic execution, one must use the
:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors,
which have the signatures::

  lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

  lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

Both of these will bind a primitive called ``cond`` internally. The
``cond`` primitive in jaxprs reflects the more general signature of
:py:func:`lax.switch`: it takes an integer denoting the index of the branch
to execute (clamped into valid indexing range).

For example:

>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
...   return lax.switch(index, [lambda x: x + 1.,
...                             lambda x: x - 2.,
...                             lambda x: x + 3.],
...                     arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a b.
  let c = clamp 0 a 2
      d = cond[ branches=( { lambda ; a.
                             let b = add a 1.0
                             in (b,) }
                           { lambda ; a.
                             let b = sub a 2.0
                             in (b,) }
                           { lambda ; a.
                             let b = add a 3.0
                             in (b,) } )
                linear=(False,) ] c b
  in (d,) }

The cond primitive has a number of parameters:

* `branches` are jaxprs that correspond to the branch
  functionals. In this example, those functionals each take one
  input variable, corresponding to ``x``.
* `linear` is a tuple of booleans that is used internally by the
  auto-differentiation machinery to encode which of the input
  parameters are used linearly in the conditional.

The above instance of the cond primitive takes two operands. The first
one (``c``) is the branch index, then ``b`` is the operand (``arg``) to
be passed to whichever jaxpr in ``branches`` is selected by the branch
index.
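
Because the index is clamped, calling :py:func:`lax.switch` with an
out-of-range index executes the nearest valid branch rather than failing. A
small runnable check (the branch list is our own toy example):

```python
from jax import lax

branches = [lambda x: x + 1.,
            lambda x: x - 2.,
            lambda x: x + 3.]

print(lax.switch(1, branches, 5.))   # middle branch: 5. - 2. = 3.
print(lax.switch(10, branches, 5.))  # clamped to index 2: 5. + 3. = 8.
print(lax.switch(-7, branches, 5.))  # clamped to index 0: 5. + 1. = 6.
```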

Another example, using :py:func:`lax.cond`:

>>> from jax import lax
>>>
>>> def func7(arg):
...   return lax.cond(arg >= 0.,
...                   lambda xtrue: xtrue + 3.,
...                   lambda xfalse: xfalse - 3.,
...                   arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a.
  let b = ge a 0.0
      c = convert_element_type[ new_dtype=int32
                                old_dtype=bool ] b
      d = cond[ branches=( { lambda ; a.
                             let b = sub a 3.0
                             in (b,) }
                           { lambda ; a.
                             let b = add a 3.0
                             in (b,) } )
                linear=(False,) ] c a
  in (d,) }

In this case, the boolean predicate is converted to an integer index
(0 or 1), and ``branches`` are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
one input variable, corresponding to ``xtrue`` and ``xfalse``
respectively.
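
The predicate-to-index conversion is invisible when simply calling
:py:func:`lax.cond`: both branches are traced, but only the selected one is
executed for a given predicate. A small runnable sketch (``abs_shift`` is our
own name):

```python
from jax import lax

def abs_shift(x):
    # A True predicate selects the true branch, False the false branch.
    return lax.cond(x >= 0.,
                    lambda v: v + 3.,
                    lambda v: v - 3.,
                    x)

print(abs_shift(5.))   # true branch: 5. + 3. = 8.
print(abs_shift(-5.))  # false branch: -5. - 3. = -8.
```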

The following example shows a more complicated situation when the input
to the branch functionals is a tuple, and the `false` branch functional
contains a constant ``jnp.array([1])`` that is hoisted as a `constvar`

>>> def func8(arg1, arg2):  # arg2 is a pair
...   return lax.cond(arg1 >= 0.,
...                   lambda xtrue: xtrue[0],
...                   lambda xfalse: jnp.array([1]) + xfalse[1],
...                   arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a ; b c d.
  let e = ge b 0.0
      f = convert_element_type[ new_dtype=int32
                                old_dtype=bool ] e
      g = cond[ branches=( { lambda ; a b c.
                             let d = convert_element_type[ new_dtype=float32
                                                           old_dtype=int32 ] a
                                 e = add d c
                             in (e,) }
                           { lambda ; f_ a b.
                             let
                             in (a,) } )
                linear=(False, False, False) ] f a c d
  in (g,) }

While
^^^^^

Just like for conditionals, Python loops are inlined during tracing.
If you want to capture a loop for dynamic execution, you must use one of several
special operations, :py:func:`jax.lax.while_loop` (a primitive)
and :py:func:`jax.lax.fori_loop`
(a helper that generates a while_loop primitive)::

  lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
  lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C

In the above signatures, “C” stands for the type of the loop “carry” value.

For example, here is a fori loop

>>> import numpy as np
>>>
>>> def func10(arg, n):
...   ones = jnp.ones(arg.shape)  # A constant
...   return lax.fori_loop(0, n,
...                        lambda i, carry: carry + ones * 3. + arg,
...                        arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=( )
                            shape=(16,) ] 1.0
      d = add a c
      _ _ e = while[ body_jaxpr={ lambda ; a b c d e.
                                  let f = add c 1
                                      g = mul a 3.0
                                      h = add e g
                                      i = add h b
                                  in (f, d, i) }
                     body_nconsts=2
                     cond_jaxpr={ lambda ; a b c.
                                  let d = lt a b
                                  in (d,) }
                     cond_nconsts=0 ] c a 0 b d
  in (e,) }

The while primitive takes 5 arguments: ``c a 0 b d``, as follows:

* 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0)
* 2 constants for ``body_jaxpr`` (``c``, and ``a``)
* 3 parameters for the initial value of carry
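
The captured loop can of course also be executed directly, not just traced.
For instance, here is a fori loop that sums the integers below ``n`` (a toy
example of ours):

```python
from jax import lax

# The carry starts at 0; the body adds the loop index i on each iteration,
# so the result is 0 + 1 + 2 + 3 + 4.
total = lax.fori_loop(0, 5, lambda i, carry: carry + i, 0)
print(total)
```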

Scan
^^^^

JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there are a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the :py:func:`jax.lax.scan` function::

  lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

Here ``C`` is the type of the scan carry, ``A`` is the element type of the
input array(s), and ``B`` is the element type of the output array(s).

For example, consider the function ``func11`` below

>>> def func11(arr, extra):
...   ones = jnp.ones(arr.shape)  # A constant
...   def body(carry, aelems):
...     # carry: running dot-product of the two arrays
...     # aelems: a pair with corresponding elements from the two arrays
...     ae1, ae2 = aelems
...     return (carry + ae1 * ae2 + extra, carry)
...   return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=( )
                            shape=(16,) ] 1.0
      d e = scan[ jaxpr={ lambda ; a b c d.
                          let e = mul c d
                              f = add b e
                              g = add f a
                          in (g, b) }
                  length=16
                  linear=(False, False, False, False)
                  num_carry=1
                  num_consts=1
                  reverse=False
                  unroll=1 ] b 0.0 a c
  in (d, e) }

The ``linear`` parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.

The scan primitive takes 4 arguments: ``b 0.0 a c``, of which:

* one is the free variable for the body
* one is the initial value of the carry
* the next 2 are the arrays over which the scan operates.
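
Evaluating a scan makes the carry/output split concrete. In the toy example
below (our own), the carry is a running sum and the per-element output ``B``
records the running sum at each step:

```python
import jax.numpy as jnp
from jax import lax

def body(carry, x):
    new_carry = carry + x
    # Return (C, B): the carry for the next step and this step's output.
    return new_carry, new_carry

final, outputs = lax.scan(body, 0.0, jnp.arange(4.0))
print(final)    # final carry: 0 + 1 + 2 + 3
print(outputs)  # stacked per-step outputs: the running sums
```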

XLA_call
^^^^^^^^

The call primitive arises from JIT compilation, and it encapsulates
a sub-jaxpr along with parameters that specify the backend and the device on
which the computation should run. For example

>>> from jax import jit
>>>
>>> def func12(arg):
...   @jit
...   def inner(x):
...     return x + arg * jnp.ones(1)  # Include a constant in the inner function
...   return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda ; a.
  let b = sub a 2.0
      c = xla_call[ backend=None
                    call_jaxpr={ lambda ; a b.
                                 let c = broadcast_in_dim[ broadcast_dimensions=( )
                                                           shape=(1,) ] 1.0
                                     d = mul a c
                                     e = add b d
                                 in (e,) }
                    device=None
                    donated_invars=(False, False)
                    name=inner ] a b
      d = add a c
  in (d,) }

XLA_pmap
^^^^^^^^

If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
captured using the ``xla_pmap`` primitive. Consider this example

>>> from jax import pmap
>>>
>>> def func13(arr, extra):
...   def inner(x):
...     # use a free variable "extra" and a constant jnp.ones(1)
...     return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
...   return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a b.
  let c = xla_pmap[ axis_name=rows
                    axis_size=1
                    backend=None
                    call_jaxpr={ lambda ; a b.
                                 let c = add b a
                                     d = broadcast_in_dim[ broadcast_dimensions=( )
                                                           shape=(1,) ] 1.0
                                     e = add c d
                                     f = psum[ axis_index_groups=None
                                               axis_name=rows ] b
                                     g = div e f
                                 in (g,) }
                    devices=None
                    donated_invars=(False, False)
                    global_axis_size=None
                    mapped_invars=(False, True)
                    name=inner ] b a
  in (c,) }

The ``xla_pmap`` primitive specifies the name of the axis (here ``rows``)
and the body of the function to be mapped as the ``call_jaxpr`` parameter. The
value of this parameter is a Jaxpr with 2 input variables.

The parameter ``mapped_invars`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast, the other input values are mapped.
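
The mapped/broadcast distinction mirrors what :py:func:`jax.pmap` exposes
through its ``in_axes`` argument. Even on a single-device installation you can
run a pmap over a leading axis of size 1; in this sketch (our own toy
function) ``x`` is mapped while ``extra`` is broadcast:

```python
import jax
import jax.numpy as jnp

# x is mapped over its leading axis; extra (in_axes=None) is broadcast
# to every device unchanged.
f = jax.pmap(lambda x, extra: x + extra, in_axes=(0, None))
out = f(jnp.ones((1, 3)), 5.0)
print(out)  # a (1, 3) array of 6.0
```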