Understanding Jaxprs
====================

Updated: May 3, 2020 (for commit f1a46fe).

Conceptually, one can think of JAX transformations as first trace-specializing
the Python function to be transformed into a small and well-behaved
intermediate form that is then interpreted with transformation-specific
interpretation rules. One of the reasons JAX can pack so much power into such a
small software package is that it starts with a familiar and flexible
programming interface (Python with NumPy) and it uses the actual Python
interpreter to do most of the heavy lifting to distill the essence of the
computation into a simple statically-typed expression language with limited
higher-order features. That language is the jaxpr language.

Not all Python programs can be processed this way, but it turns out that many
scientific computing and machine learning programs can.

Before we proceed, it is important to point out that not all JAX
transformations literally materialize a jaxpr as described above; some, e.g.,
differentiation or batching, will apply transformations incrementally during
tracing. Nevertheless, if one wants to understand how JAX works internally, or
to make use of the result of JAX tracing, it is useful to understand jaxprs.

A jaxpr instance represents a function with one or more typed parameters (input
variables) and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes. The
inputs and outputs have types, which in JAX are represented as abstract values.
There are two related representations in the code for jaxprs,
:py:class:`jax.core.Jaxpr` and :py:class:`jax.core.ClosedJaxpr`. A
:py:class:`jax.core.ClosedJaxpr` represents a partially-applied
:py:class:`jax.core.Jaxpr`, and is what you obtain when you use
:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields:

* ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual
  computation content of the function (described below).
* ``consts`` is a list of constants.
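
For instance, these fields can be inspected on the object returned by
:py:func:`jax.make_jaxpr`. A rough sketch (the attribute names below reflect
the internals at this commit and may change)::

  closed_jaxpr = make_jaxpr(lambda x: x + jnp.ones(8))(jnp.zeros(8))
  closed_jaxpr.jaxpr    # the jax.core.Jaxpr: its invars, constvars, eqns and outvars
  closed_jaxpr.consts   # the hoisted constant values, here the array made by jnp.ones(8)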

The most interesting part of the ClosedJaxpr is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr` as printed using the following
grammar::

  jaxpr ::= { lambda Var* ; Var+.
              let Eqn*
              in  [Expr+] }

where:

* The parameters of the jaxpr are shown as two lists of variables separated by
  ``;``. The first set of variables are the ones that have been introduced
  to stand for constants that have been hoisted out. These are called the
  ``constvars``, and in a :py:class:`jax.core.ClosedJaxpr` the ``consts``
  field holds corresponding values. The second list of variables, called
  ``invars``, corresponds to the inputs of the traced Python function.
* ``Eqn*`` is a list of equations, defining intermediate variables referring to
  intermediate expressions. Each equation defines one or more variables as the
  result of applying a primitive to some atomic expressions. Each equation uses
  only input variables and intermediate variables defined by previous equations.
* ``Expr+`` is a list of output atomic expressions (literals or variables)
  for the jaxpr.

Equations are printed as follows::

  Eqn ::= let Var+ = Primitive [ Param* ] Expr+

where:

* ``Var+`` are one or more intermediate variables to be defined as the
  output of a primitive invocation (some primitives can return multiple values).
* ``Expr+`` are one or more atomic expressions, each either a variable or a
  literal constant. A special variable ``unitvar`` or literal ``unit``,
  printed as ``*``, represents a value that is not needed
  in the rest of the computation and has been elided. That is, units are just
  placeholders.
* ``Param*`` are zero or more named parameters to the primitive, printed in
  square brackets. Each parameter is shown as ``Name = Value``.
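
As an illustration, here is how one equation from the first example below lines
up with this grammar (spacing adjusted; purely a reading aid)::

  f    =    reduce_sum [ axes=(0,) ]    e
  Var       Primitive    Param*         Expr+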

Most jaxpr primitives are first-order (they take just one or more Expr as arguments)::

  Primitive := add | sub | sin | mul | ...

The jaxpr primitives are documented in the :py:mod:`jax.lax` module.

For example, here is the jaxpr produced for the function ``func1`` below

>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
...    temp = first + jnp.sin(second) * 3.
...    return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }

Here there are no constvars, ``a`` and ``b`` are the input variables
and they correspond respectively to
``first`` and ``second`` function parameters. The scalar literal ``3.0`` is kept
inline.
The ``reduce_sum`` primitive has the named parameter ``axes``, in
addition to the operand ``e``.

Note that JAX traces through Python-level control-flow and higher-order functions
when it extracts the jaxpr. This means that even if a Python program contains
functions and control-flow, the resulting jaxpr does not have
to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
jaxpr as before

>>> def func2(inner, first, second):
...   temp = first + inner(second) * 3.
...   return jnp.sum(temp)
...
>>> def inner(second):
...   if second.shape[0] > 4:
...     return jnp.sin(second)
...   else:
...     assert False
...
>>> def func3(first, second):
...   return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }


Handling PyTrees
----------------

In jaxprs there are no tuple types; instead primitives take multiple inputs
and produce multiple outputs. When processing a function that has structured
inputs or outputs, JAX will flatten those and in the jaxpr they will appear as
lists of inputs and outputs. For more details, please see the documentation for
PyTrees (:doc:`notebooks/JAX_pytrees`).
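
The flattening follows the same rules as :py:func:`jax.tree_util.tree_flatten`.
As a small sketch of what happens under the hood (illustrative only)::

  from jax import tree_util

  leaves, treedef = tree_util.tree_flatten(((jnp.zeros(8), jnp.ones(8)),))
  # leaves is a flat list of two arrays and treedef records the tuple structure;
  # this is how a single pair argument gives rise to two invars in the jaxpr below.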

For example, the following code produces an identical jaxpr to what we saw
before (with two input vars, one for each element of the input tuple)

>>> def func4(arg):  # Arg is a pair
...   temp = arg[0] + jnp.sin(arg[1]) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a b.
  let c = sin b
      d = mul c 3.0
      e = add a d
      f = reduce_sum[ axes=(0,) ] e
  in (f,) }


Constant Vars
-------------

Some values in jaxprs are constants, in that their value does not depend on the
jaxpr's arguments. When these values are scalars they are represented directly
in the jaxpr equations; non-scalar array constants are instead hoisted out to
the top-level jaxpr, where they correspond to constant variables ("constvars").
These constvars differ from the other jaxpr parameters ("invars") only as a
bookkeeping convention.
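
For instance, a function that closes over a non-scalar array constant gets a
constvar; a sketch (the exact variable names in the printed jaxpr may differ)::

  ones = jnp.ones(8)
  def func5(x):
    return x + ones        # "ones" is not an argument, so it gets hoisted

  # make_jaxpr(func5)(jnp.zeros(8)) prints something like
  #   { lambda a ; b.
  #     let c = add b a
  #     in (c,) }
  # where the constvar a holds the value of "ones", stored in the ClosedJaxpr's consts.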


Higher-order primitives
-----------------------

jaxpr includes several higher-order primitives. They are more complicated because
they include sub-jaxprs.

Conditionals
^^^^^^^^^^^^

JAX traces through normal Python conditionals. To capture a
conditional expression for dynamic execution, one must use the
:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors,
which have the signatures::

  lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

  lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

Both of these will bind a primitive called ``cond`` internally. The
``cond`` primitive in jaxprs reflects the more general signature of
:py:func:`lax.switch`: it takes an integer denoting the index of the branch
to execute (clamped into valid indexing range).
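
In other words, :py:func:`lax.cond` can be thought of as a two-branch
:py:func:`lax.switch` whose index is the predicate cast to an integer. A rough
sketch (not the actual implementation)::

  def cond_sketch(pred, true_body, false_body, operand):
    # index 0 selects the false branch, index 1 the true branch
    return lax.switch(jnp.asarray(pred, jnp.int32), [false_body, true_body], operand)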

For example:

>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
...   return lax.switch(index, [lambda x: x + 1.,
...                             lambda x: x - 2.,
...                             lambda x: x + 3.],
...                     arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a b.
  let c = clamp 0 a 2
      d = cond[ branches=( { lambda ; a.
                             let b = add a 1.0
                             in (b,) }
                           { lambda ; a.
                             let b = sub a 2.0
                             in (b,) }
                           { lambda ; a.
                             let b = add a 3.0
                             in (b,) } )
                linear=(False,) ] c b
  in (d,) }

The cond primitive has a number of parameters:

  * `branches` are jaxprs that correspond to the branch
    functionals. In this example, those functionals each take one
    input variable, corresponding to ``x``.
  * `linear` is a tuple of booleans that is used internally by the
    auto-differentiation machinery to encode which of the input
    parameters are used linearly in the conditional.

The above instance of the cond primitive takes two operands. The first
one (``c``) is the branch index, then ``b`` is the operand (``arg``) to
be passed to whichever jaxpr in ``branches`` is selected by the branch
index.

Another example, using :py:func:`lax.cond`:

>>> from jax import lax
>>>
>>> def func7(arg):
...   return lax.cond(arg >= 0.,
...                   lambda xtrue: xtrue + 3.,
...                   lambda xfalse: xfalse - 3.,
...                   arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a.
  let b = ge a 0.0
      c = convert_element_type[ new_dtype=int32 ] b
      d = cond[ branches=( { lambda ; a.
                             let b = sub a 3.0
                             in (b,) }
                           { lambda ; a.
                             let b = add a 3.0
                             in (b,) } )
                linear=(False,) ] c a
  in (d,) }

In this case, the boolean predicate is converted to an integer index
(0 or 1), and ``branches`` are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
one input variable, corresponding to ``xtrue`` and ``xfalse``
respectively.

The following example shows a more complicated situation when the input
to the branch functionals is a tuple, and the `false` branch functional
contains a constant ``jnp.array([1])`` that is hoisted as a `constvar`

>>> def func8(arg1, arg2):  # arg2 is a pair
...   return lax.cond(arg1 >= 0.,
...                   lambda xtrue: xtrue[0],
...                   lambda xfalse: jnp.array([1]) + xfalse[1],
...                   arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a ; b c d.
  let e = ge b 0.0
      f = convert_element_type[ new_dtype=int32 ] e
      g = cond[ branches=( { lambda ; a b c.
                             let d = convert_element_type[ new_dtype=float32 ] a
                                 e = add d c
                             in (e,) }
                           { lambda ; f_ a b.
                             let
                             in (a,) } )
                linear=(False, False, False) ] f a c d
  in (g,) }


While
^^^^^

Just like for conditionals, Python loops are inlined during tracing.
If you want to capture a loop for dynamic execution, you must use one of several
special operations, :py:func:`jax.lax.while_loop` (a primitive)
and :py:func:`jax.lax.fori_loop`
(a helper that generates a while_loop primitive)::

    lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
    lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C

In the above signatures, ``C`` stands for the type of the loop "carry" value.
For example, here is a fori loop

>>> import numpy as np
>>>
>>> def func10(arg, n):
...   ones = jnp.ones(arg.shape)  # A constant
...   return lax.fori_loop(0, n,
...                        lambda i, carry: carry + ones * 3. + arg,
...                        arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=( )
                            shape=(16,) ] 1.0
      d = add a c
      _ _ e = while[ body_jaxpr={ lambda ; a b c d e.
                                  let f = add c 1
                                      g = mul a 3.0
                                      h = add e g
                                      i = add h b
                                  in (f, d, i) }
                     body_nconsts=2
                     cond_jaxpr={ lambda ; a b c.
                                  let d = lt a b
                                  in (d,) }
                     cond_nconsts=0 ] c a 0 b d
  in (e,) }

The while primitive takes 5 arguments: ``c a 0 b d``, as follows:

  * 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0)
  * 2 constants for ``body_jaxpr`` (``c`` and ``a``)
  * 3 parameters for the initial value of the carry
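
To see why only a ``while`` primitive appears in the jaxpr, here is a rough
sketch of how :py:func:`lax.fori_loop` can be expressed with
:py:func:`lax.while_loop` (the real lowering differs in bookkeeping details)::

  def fori_loop_sketch(start, stop, body, init):
    # carry the loop index alongside the user's carry value
    def cond_fun(state):
      i, _ = state
      return i < stop
    def body_fun(state):
      i, carry = state
      return i + 1, body(i, carry)
    _, result = lax.while_loop(cond_fun, body_fun, (start, init))
    return result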


Scan
^^^^

JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there is a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the :py:func:`jax.lax.scan` function::

  lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

Here ``C`` is the type of the scan carry, ``A`` is the element type of the
input array(s), and ``B`` is the element type of the output array(s).
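
Ignoring pytree handling, the semantics of :py:func:`lax.scan` can be sketched
as a Python loop (a reference sketch only; the primitive does not actually
build a Python list)::

  def scan_sketch(body_fun, init_carry, xs):
    carry = init_carry
    ys = []
    for x in xs:
      carry, y = body_fun(carry, x)
      ys.append(y)
    return carry, np.stack(ys)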

For example, consider the function ``func11`` below

>>> def func11(arr, extra):
...   ones = jnp.ones(arr.shape)  # A constant
...   def body(carry, aelems):
...     # carry: running dot-product of the two arrays
...     # aelems: a pair with corresponding elements from the two arrays
...     ae1, ae2 = aelems
...     return (carry + ae1 * ae2 + extra, carry)
...   return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=( )
                            shape=(16,) ] 1.0
      d e = scan[ jaxpr={ lambda ; a b c d.
                          let e = mul c d
                              f = add b e
                              g = add f a
                          in (g, b) }
                  length=16
                  linear=(False, False, False, False)
                  num_carry=1
                  num_consts=1
                  reverse=False
                  unroll=1 ] b 0.0 a c
  in (d, e) }

The ``linear`` parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.

The scan primitive takes 4 arguments: ``b 0.0 a c``, of which:

  * one is the free variable for the body
  * one is the initial value of the carry
  * the next 2 are the arrays over which the scan operates.

XLA_call
^^^^^^^^

The call primitive arises from JIT compilation, and it encapsulates
a sub-jaxpr along with parameters that specify the backend and the device on
which the computation should run. For example

>>> from jax import jit
>>>
>>> def func12(arg):
...   @jit
...   def inner(x):
...     return x + arg * jnp.ones(1)  # Include a constant in the inner function
...   return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda ; a.
  let b = sub a 2.0
      c = xla_call[ backend=None
                    call_jaxpr={ lambda ; a b.
                                 let c = broadcast_in_dim[ broadcast_dimensions=( )
                                                           shape=(1,) ] 1.0
                                     d = mul a c
                                     e = add b d
                                 in (e,) }
                    device=None
                    donated_invars=(False, False)
                    name=inner ] a b
      d = add a c
  in (d,) }
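
Conversely, without the :py:func:`jit` decoration the call is simply inlined
during tracing and no ``xla_call`` equation appears; a quick sketch to try::

  def func12_nojit(arg):
    def inner(x):
      return x + arg * jnp.ones(1)
    return arg + inner(arg - 2.)

  print(make_jaxpr(func12_nojit)(1.))  # no xla_call equation in the output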


XLA_pmap
^^^^^^^^

If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
captured using the ``xla_pmap`` primitive. Consider this example

>>> from jax import pmap
>>>
>>> def func13(arr, extra):
...   def inner(x):
...     # use a free variable "extra" and a constant jnp.ones(1)
...     return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
...   return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a b.
  let c = xla_pmap[ axis_name=rows
                    axis_size=1
                    backend=None
                    call_jaxpr={ lambda ; a b.
                                 let c = add b a
                                     d = broadcast_in_dim[ broadcast_dimensions=( )
                                                           shape=(1,) ] 1.0
                                     e = add c d
                                     f = psum[ axis_index_groups=None
                                               axis_name=('rows',) ] b
                                     g = div e f
                                 in (g,) }
                    devices=None
                    donated_invars=(False, False)
                    global_arg_shapes=(None,)
                    global_axis_size=None
                    in_axes=(None, 0)
                    name=inner
                    out_axes=(0,) ] b a
  in (c,) }

The ``xla_pmap`` primitive specifies the name of the axis (parameter
``axis_name``, here ``rows``) and the body of the function to be mapped as the
``call_jaxpr`` parameter. The value of this parameter is a Jaxpr with 2 input
variables.

The parameter ``in_axes`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast and the other input values are mapped.
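
For comparison, the same mapped/broadcast split can be requested explicitly
through the user-facing ``in_axes`` parameter of :py:func:`jax.pmap`, by
passing ``extra`` as an argument instead of closing over it (a sketch; the
argument order and axis choices here are illustrative)::

  def inner(x, extra):
    return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')

  # arr is mapped along its leading axis; extra is broadcast to every device
  pmap(inner, axis_name='rows', in_axes=(0, None))(jnp.ones((1, 3)), 5.)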