Update jaxpr.rst (#2859)

* Update jaxpr doc

* Make jaxpr.rst doctestable
This commit is contained in:
Jamie Townsend 2020-04-28 00:44:46 +01:00 committed by GitHub
parent 5da74d4b82
commit 283393f773
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -86,21 +86,21 @@ Most jaxpr primitives are first-order (they take just one or more Expr as argume
The jaxpr primitives are documented in the :py:mod:`jax.lax` module.
For example, here is the jaxpr produced for the function ``func1`` below::
For example, here is the jaxpr produced for the function ``func1`` below
from jax import numpy as jnp
def func1(first, second):
temp = first + jnp.sin(second) * 3.
return jnp.sum(temp)
print(jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,)
input_shape=(8,) ] e
in f }
>>> from jax import make_jaxpr
>>> from jax import numpy as jnp
>>> def func1(first, second):
... temp = first + jnp.sin(second) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
in (f,) }
Here there are no constvars, ``a`` and ``b`` are the input variables
and they correspond respectively to
@ -115,29 +115,29 @@ functions and control-flow, the resulting jaxpr does not have
to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
jaxpr as before::
jaxpr as before
def func2(inner, first, second):
temp = first + inner(second) * 3.
return jnp.sum(temp)
>>> def func2(inner, first, second):
... temp = first + inner(second) * 3.
... return jnp.sum(temp)
...
>>> def inner(second):
... if second.shape[0] > 4:
... return jnp.sin(second)
... else:
... assert False
...
>>> def func3(first, second):
... return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
in (f,) }
def inner(second):
if second.shape[0] > 4:
return jnp.sin(second)
else:
assert False
def func3(first, second):
return func2(inner, first, second)
print(api.make_jaxpr(func2)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,)
input_shape=(8,) ] e
in f }
Handling PyTrees
----------------
@ -149,21 +149,20 @@ of inputs and outputs. For more details, please see the documentation for
PyTrees (:doc:`notebooks/JAX_pytrees`).
For example, the following code produces an identical jaxpr to what we saw
before (with two input vars, one for each element of the input tuple)::
before (with two input vars, one for each element of the input tuple)
def func4(arg): # Arg is a pair
temp = arg[0] + jnp.sin(arg[1]) * 3.
return jnp.sum(temp)
print(api.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,)
input_shape=(8,) ] e
in f }
>>> def func4(arg): # Arg is a pair
... temp = arg[0] + jnp.sin(arg[1]) * 3.
... return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a b.
let c = sin b
d = mul c 3.0
e = add a d
f = reduce_sum[ axes=(0,) ] e
in (f,) }
@ -172,7 +171,9 @@ Constant Vars
ConstVars arise when the computation contains array constants, either
from the Python program, or from constant-folding. For example, the function
``func6`` below::
``func6`` below
.. testcode::
def func5(first, second):
temp = first + jnp.sin(second) * 3. - jnp.ones(8)
@ -181,15 +182,17 @@ from the Python program, or from constant-folding. For example, the function
def func6(first):
return func5(first, jnp.ones(8))
print(api.make_jaxpr(func6)(jnp.ones(8)))
print(make_jaxpr(func6)(jnp.ones(8)))
JAX produces the following jaxpr::
JAX produces the following jaxpr
{ lambda b d a.
.. testoutput::
{ lambda b d ; a.
let c = add a b
e = sub c d
in e }
in (e,) }
When tracing ``func6``, the function ``func5`` is invoked with a constant value
(``jnp.ones(8)``) for the second argument. As a result, the sub-expression
@ -213,27 +216,29 @@ with the following signature::
lax.cond(pred : bool, true_op: A, true_body: A -> B, false_op: C, false_body: C -> B) -> B
For example::
For example
def func7(arg):
return lax.cond(arg >= 0.,
arg,
lambda xtrue: xtrue + 3.,
arg,
lambda xfalse: xfalse - 3.)
print(api.make_jaxpr(func7)(5.))
{ lambda ; a.
let b = ge a 0.0
c = cond[ false_jaxpr={ lambda ; a.
let b = sub a 3.0
in b }
linear=(False, False)
true_jaxpr={ lambda ; a.
let b = add a 3.0
in b } ] b a a
in c }
>>> from jax import lax
>>>
>>> def func7(arg):
... return lax.cond(arg >= 0.,
... arg,
... lambda xtrue: xtrue + 3.,
... arg,
... lambda xfalse: xfalse - 3.)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a.
let b = ge a 0.0
c = cond[ false_jaxpr={ lambda ; a.
let b = sub a 3.0
in (b,) }
linear=(False, False)
true_jaxpr={ lambda ; a.
let b = add a 3.0
in (b,) } ] b a a
in (c,) }
The cond primitive has a number of parameters:
@ -252,26 +257,26 @@ passed to ``true_jaxpr``) and also ``a`` is the ``false_op``
The following example shows a more complicated situation when the input
to the branch functionals is a tuple, and the `false` branch functional
contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`::
contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
def func8(arg1, arg2): # arg2 is a pair
return lax.cond(arg1 >= 0.,
arg2,
lambda xtrue: xtrue[0],
arg2,
lambda xfalse: jnp.ones(1) + xfalse[1])
print(api.make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda e ; a b c.
let d = ge a 0.0
f = cond[ false_jaxpr={ lambda ; c a b.
let d = add c b
in d }
linear=(False, False, False, False, False)
true_jaxpr={ lambda ; a b.
let
in a } ] d b c e b c
in f }
>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... arg2,
... lambda xtrue: xtrue[0],
... arg2,
... lambda xfalse: jnp.ones(1) + xfalse[1])
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda e ; a b c.
let d = ge a 0.0
f = cond[ false_jaxpr={ lambda ; c a b.
let d = add c b
in (d,) }
linear=(False, False, False, False, False)
true_jaxpr={ lambda ; a b.
let
in (a,) } ] d b c e b c
in (f,) }
The top-level jaxpr has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
body of the ``false_jaxpr``) and three input variables ``a b c`` (corresponding to ``arg1``
@ -305,28 +310,30 @@ and :py:func:`jax.lax.fori_loop`
In the above signature, “C” stands for the type of the loop “carry” value.
For example, here is a fori loop::
For example, here is a fori loop
def func10(arg, n):
ones = jnp.ones(arg.shape) # A constant
return lax.fori_loop(0, n,
lambda i, carry: carry + ones * 3. + arg,
arg + ones)
print(api.make_jaxpr(func10)(onp.ones(16), 5))
{ lambda c d ; a b.
let e = add a d
f g h = while[ body_jaxpr={ lambda ; e g a b c.
let d = add a 1
f = add c e
h = add f g
in (d, b, h) }
body_nconsts=2
cond_jaxpr={ lambda ; a b c.
let d = lt a b
in d }
cond_nconsts=0 ] c a 0 b e
in h }
>>> import numpy as onp
>>>
>>> def func10(arg, n):
... ones = jnp.ones(arg.shape) # A constant
... return lax.fori_loop(0, n,
... lambda i, carry: carry + ones * 3. + arg,
... arg + ones)
...
>>> print(make_jaxpr(func10)(onp.ones(16), 5))
{ lambda c d ; a b.
let e = add a d
f g h = while[ body_jaxpr={ lambda ; e g a b c.
let d = add a 1
f = add c e
h = add f g
in (d, b, h) }
body_nconsts=2
cond_jaxpr={ lambda ; a b c.
let d = lt a b
in (d,) }
cond_nconsts=0 ] c a 0 b e
in (h,) }
The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
@ -360,31 +367,30 @@ with the :py:func:`jax.lax.scan` operator::
Here ``C`` is the type of the scan carry, ``A`` is the element type of the input array(s),
and ``B`` is the element type of the output array(s).
For example, consider the function ``func11`` below::
For example, consider the function ``func11`` below
def func11(arr, extra):
ones = jnp.ones(arr.shape) # A constant
def body(carry, aelems):
# carry: running dot-product of the two arrays
# aelems: a pair with corresponding elements from the two arrays
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
print(api.make_jaxpr(func11)(onp.ones(16), 5.))
{ lambda c ; a b.
let d e = scan[ forward=True
jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
in (g, a) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1 ] b 0.0 a c
in (d, e) }
>>> def func11(arr, extra):
... ones = jnp.ones(arr.shape) # A constant
... def body(carry, aelems):
... # carry: running dot-product of the two arrays
... # aelems: a pair with corresponding elements from the two arrays
... ae1, ae2 = aelems
... return (carry + ae1 * ae2 + extra, carry)
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(onp.ones(16), 5.))
{ lambda c ; a b.
let d e = scan[ forward=True
jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
in (g, a) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1 ] b 0.0 a c
in (d, e) }
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
and two input variables corresponding to the arguments ``arr`` and ``extra``.
@ -412,26 +418,28 @@ XLA_call
The call primitive arises from JIT compilation, and it encapsulates
a sub-jaxpr along with parameters that specify the backend and the device on which the
computation should run. For example::
computation should run. For example
def func12(arg):
@api.jit
def inner(x):
return x + arg * jnp.ones(1) # Include a constant in the inner function
return arg + inner(arg - 2.)
print(api.make_jaxpr(func12)(1.))
{ lambda b ; a.
let c = sub a 2.0
d = xla_call[ backend=None
call_jaxpr={ lambda ; c b a.
let d = mul b c
e = add a d
in e }
device=None
name=inner ] b a c
e = add a d
in e }
>>> from jax import jit
>>>
>>> def func12(arg):
... @jit
... def inner(x):
... return x + arg * jnp.ones(1) # Include a constant in the inner function
... return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda b ; a.
let c = sub a 2.0
d = xla_call[ backend=None
call_jaxpr={ lambda ; c b a.
let d = mul b c
e = add a d
in (e,) }
device=None
name=inner ] b a c
e = add a d
in (e,) }
The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
the top-level input variable ``a`` refers to the ``arg`` parameter of ``func12``.
@ -450,30 +458,32 @@ XLA_pmap
If you use the :py:func:`jax.pmap` transformation, the function to be
mapped is captured using the ``xla_pmap`` primitive. Consider this
example::
example
def func13(arr, extra):
def inner(x):
# use a free variable "extra" and a constant jnp.ones(1)
return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
return api.pmap(inner, axis_name='rows')(arr)
print(api.make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda c ; a b.
let d = xla_pmap[ axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d b a.
let c = add a b
e = add c d
f = psum[ axis_name=rows ] a
g = div e f
in g }
devices=None
global_axis_size=None
mapped_invars=(True, False, True)
name=inner ] c b a
in d }
>>> from jax import pmap
>>>
>>> def func13(arr, extra):
... def inner(x):
... # use a free variable "extra" and a constant jnp.ones(1)
... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda c ; a b.
let d = xla_pmap[ axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d b a.
let c = add a b
e = add c d
f = psum[ axis_name=rows ] a
g = div e f
in (g,) }
devices=None
global_axis_size=None
mapped_invars=(True, False, True)
name=inner ] c b a
in (d,) }
The top-level constvar ``c`` refers to the ``jnp.ones(1)`` constant.
The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)