mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update jaxpr.rst (#2859)
* Update jaxpr doc * Make jaxpr.rst doctestable
This commit is contained in:
parent
5da74d4b82
commit
283393f773
366
docs/jaxpr.rst
366
docs/jaxpr.rst
@ -86,21 +86,21 @@ Most jaxpr primitives are first-order (they take just one or more Expr as argume
|
||||
|
||||
The jaxpr primitives are documented in the :py:mod:`jax.lax` module.
|
||||
|
||||
For example, here is the jaxpr produced for the function ``func1`` below::
|
||||
For example, here is the jaxpr produced for the function ``func1`` below
|
||||
|
||||
from jax import numpy as jnp
|
||||
def func1(first, second):
|
||||
temp = first + jnp.sin(second) * 3.
|
||||
return jnp.sum(temp)
|
||||
|
||||
print(jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,)
|
||||
input_shape=(8,) ] e
|
||||
in f }
|
||||
>>> from jax import make_jaxpr
|
||||
>>> from jax import numpy as jnp
|
||||
>>> def func1(first, second):
|
||||
... temp = first + jnp.sin(second) * 3.
|
||||
... return jnp.sum(temp)
|
||||
...
|
||||
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,) ] e
|
||||
in (f,) }
|
||||
|
||||
Here there are no constvars, ``a`` and ``b`` are the input variables
|
||||
and they correspond respectively to
|
||||
@ -115,29 +115,29 @@ functions and control-flow, the resulting jaxpr does not have
|
||||
to contain control-flow or higher-order features.
|
||||
For example, when tracing the function ``func3`` JAX will inline the call to
|
||||
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
|
||||
jaxpr as before::
|
||||
jaxpr as before
|
||||
|
||||
def func2(inner, first, second):
|
||||
temp = first + inner(second) * 3.
|
||||
return jnp.sum(temp)
|
||||
>>> def func2(inner, first, second):
|
||||
... temp = first + inner(second) * 3.
|
||||
... return jnp.sum(temp)
|
||||
...
|
||||
>>> def inner(second):
|
||||
... if second.shape[0] > 4:
|
||||
... return jnp.sin(second)
|
||||
... else:
|
||||
... assert False
|
||||
...
|
||||
>>> def func3(first, second):
|
||||
... return func2(inner, first, second)
|
||||
...
|
||||
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,) ] e
|
||||
in (f,) }
|
||||
|
||||
def inner(second):
|
||||
if second.shape[0] > 4:
|
||||
return jnp.sin(second)
|
||||
else:
|
||||
assert False
|
||||
|
||||
def func3(first, second):
|
||||
return func2(inner, first, second)
|
||||
|
||||
print(api.make_jaxpr(func2)(jnp.zeros(8), jnp.ones(8)))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,)
|
||||
input_shape=(8,) ] e
|
||||
in f }
|
||||
|
||||
Handling PyTrees
|
||||
----------------
|
||||
@ -149,21 +149,20 @@ of inputs and outputs. For more details, please see the documentation for
|
||||
PyTrees (:doc:`notebooks/JAX_pytrees`).
|
||||
|
||||
For example, the following code produces an identical jaxpr to what we saw
|
||||
before (with two input vars, one for each element of the input tuple)::
|
||||
before (with two input vars, one for each element of the input tuple)
|
||||
|
||||
|
||||
def func4(arg): # Arg is a pair
|
||||
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
||||
return jnp.sum(temp)
|
||||
|
||||
print(api.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
|
||||
{ lambda a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,)
|
||||
input_shape=(8,) ] e
|
||||
in f }
|
||||
>>> def func4(arg): # Arg is a pair
|
||||
... temp = arg[0] + jnp.sin(arg[1]) * 3.
|
||||
... return jnp.sum(temp)
|
||||
...
|
||||
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
|
||||
{ lambda ; a b.
|
||||
let c = sin b
|
||||
d = mul c 3.0
|
||||
e = add a d
|
||||
f = reduce_sum[ axes=(0,) ] e
|
||||
in (f,) }
|
||||
|
||||
|
||||
|
||||
@ -172,7 +171,9 @@ Constant Vars
|
||||
|
||||
ConstVars arise when the computation contains array constants, either
|
||||
from the Python program, or from constant-folding. For example, the function
|
||||
``func6`` below::
|
||||
``func6`` below
|
||||
|
||||
.. testcode::
|
||||
|
||||
def func5(first, second):
|
||||
temp = first + jnp.sin(second) * 3. - jnp.ones(8)
|
||||
@ -181,15 +182,17 @@ from the Python program, or from constant-folding. For example, the function
|
||||
def func6(first):
|
||||
return func5(first, jnp.ones(8))
|
||||
|
||||
print(api.make_jaxpr(func6)(jnp.ones(8)))
|
||||
print(make_jaxpr(func6)(jnp.ones(8)))
|
||||
|
||||
|
||||
JAX produces the following jaxpr::
|
||||
JAX produces the following jaxpr
|
||||
|
||||
{ lambda b d a.
|
||||
.. testoutput::
|
||||
|
||||
{ lambda b d ; a.
|
||||
let c = add a b
|
||||
e = sub c d
|
||||
in e }
|
||||
in (e,) }
|
||||
|
||||
When tracing ``func6``, the function ``func5`` is invoked with a constant value
|
||||
(``jnp.ones(8)``) for the second argument. As a result, the sub-expression
|
||||
@ -213,27 +216,29 @@ with the following signature::
|
||||
|
||||
lax.cond(pred : bool, true_op: A, true_body: A -> B, false_op: C, false_body: C -> B) -> B
|
||||
|
||||
For example::
|
||||
For example
|
||||
|
||||
|
||||
def func7(arg):
|
||||
return lax.cond(arg >= 0.,
|
||||
arg,
|
||||
lambda xtrue: xtrue + 3.,
|
||||
arg,
|
||||
lambda xfalse: xfalse - 3.)
|
||||
|
||||
print(api.make_jaxpr(func7)(5.))
|
||||
{ lambda ; a.
|
||||
let b = ge a 0.0
|
||||
c = cond[ false_jaxpr={ lambda ; a.
|
||||
let b = sub a 3.0
|
||||
in b }
|
||||
linear=(False, False)
|
||||
true_jaxpr={ lambda ; a.
|
||||
let b = add a 3.0
|
||||
in b } ] b a a
|
||||
in c }
|
||||
>>> from jax import lax
|
||||
>>>
|
||||
>>> def func7(arg):
|
||||
... return lax.cond(arg >= 0.,
|
||||
... arg,
|
||||
... lambda xtrue: xtrue + 3.,
|
||||
... arg,
|
||||
... lambda xfalse: xfalse - 3.)
|
||||
...
|
||||
>>> print(make_jaxpr(func7)(5.))
|
||||
{ lambda ; a.
|
||||
let b = ge a 0.0
|
||||
c = cond[ false_jaxpr={ lambda ; a.
|
||||
let b = sub a 3.0
|
||||
in (b,) }
|
||||
linear=(False, False)
|
||||
true_jaxpr={ lambda ; a.
|
||||
let b = add a 3.0
|
||||
in (b,) } ] b a a
|
||||
in (c,) }
|
||||
|
||||
|
||||
The cond primitive has a number of parameters:
|
||||
@ -252,26 +257,26 @@ passed to ``true_jaxpr``) and also ``a`` is the ``false_op``
|
||||
|
||||
The following example shows a more complicated situation when the input
|
||||
to the branch functionals is a tuple, and the `false` branch functional
|
||||
contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`::
|
||||
contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
|
||||
|
||||
def func8(arg1, arg2): # arg2 is a pair
|
||||
return lax.cond(arg1 >= 0.,
|
||||
arg2,
|
||||
lambda xtrue: xtrue[0],
|
||||
arg2,
|
||||
lambda xfalse: jnp.ones(1) + xfalse[1])
|
||||
|
||||
print(api.make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
|
||||
{ lambda e ; a b c.
|
||||
let d = ge a 0.0
|
||||
f = cond[ false_jaxpr={ lambda ; c a b.
|
||||
let d = add c b
|
||||
in d }
|
||||
linear=(False, False, False, False, False)
|
||||
true_jaxpr={ lambda ; a b.
|
||||
let
|
||||
in a } ] d b c e b c
|
||||
in f }
|
||||
>>> def func8(arg1, arg2): # arg2 is a pair
|
||||
... return lax.cond(arg1 >= 0.,
|
||||
... arg2,
|
||||
... lambda xtrue: xtrue[0],
|
||||
... arg2,
|
||||
... lambda xfalse: jnp.ones(1) + xfalse[1])
|
||||
...
|
||||
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
|
||||
{ lambda e ; a b c.
|
||||
let d = ge a 0.0
|
||||
f = cond[ false_jaxpr={ lambda ; c a b.
|
||||
let d = add c b
|
||||
in (d,) }
|
||||
linear=(False, False, False, False, False)
|
||||
true_jaxpr={ lambda ; a b.
|
||||
let
|
||||
in (a,) } ] d b c e b c
|
||||
in (f,) }
|
||||
|
||||
The top-level jaxpr has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
|
||||
body of the ``false_jaxpr``) and three input variables ``a b c`` (corresponding to ``arg1``
|
||||
@ -305,28 +310,30 @@ and :py:func:`jax.lax.fori_loop`
|
||||
|
||||
|
||||
In the above signature, “C” stands for the type of the loop “carry” value.
|
||||
For example, here is an example fori loop::
|
||||
For example, here is an example fori loop
|
||||
|
||||
def func10(arg, n):
|
||||
ones = jnp.ones(arg.shape) # A constant
|
||||
return lax.fori_loop(0, n,
|
||||
lambda i, carry: carry + ones * 3. + arg,
|
||||
arg + ones)
|
||||
|
||||
print(api.make_jaxpr(func10)(onp.ones(16), 5))
|
||||
{ lambda c d ; a b.
|
||||
let e = add a d
|
||||
f g h = while[ body_jaxpr={ lambda ; e g a b c.
|
||||
let d = add a 1
|
||||
f = add c e
|
||||
h = add f g
|
||||
in (d, b, h) }
|
||||
body_nconsts=2
|
||||
cond_jaxpr={ lambda ; a b c.
|
||||
let d = lt a b
|
||||
in d }
|
||||
cond_nconsts=0 ] c a 0 b e
|
||||
in h }
|
||||
>>> import numpy as onp
|
||||
>>>
|
||||
>>> def func10(arg, n):
|
||||
... ones = jnp.ones(arg.shape) # A constant
|
||||
... return lax.fori_loop(0, n,
|
||||
... lambda i, carry: carry + ones * 3. + arg,
|
||||
... arg + ones)
|
||||
...
|
||||
>>> print(make_jaxpr(func10)(onp.ones(16), 5))
|
||||
{ lambda c d ; a b.
|
||||
let e = add a d
|
||||
f g h = while[ body_jaxpr={ lambda ; e g a b c.
|
||||
let d = add a 1
|
||||
f = add c e
|
||||
h = add f g
|
||||
in (d, b, h) }
|
||||
body_nconsts=2
|
||||
cond_jaxpr={ lambda ; a b c.
|
||||
let d = lt a b
|
||||
in (d,) }
|
||||
cond_nconsts=0 ] c a 0 b e
|
||||
in (h,) }
|
||||
|
||||
The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
|
||||
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
|
||||
@ -360,31 +367,30 @@ with the :py:func:`jax.lax.scan` operator::
|
||||
Here ``C`` is the type of the scan carry, ``A`` is the element type of the input array(s),
|
||||
and ``B`` is the element type of the output array(s).
|
||||
|
||||
For the example consider the function ``func11`` below::
|
||||
For example, consider the function ``func11`` below
|
||||
|
||||
def func11(arr, extra):
|
||||
ones = jnp.ones(arr.shape) # A constant
|
||||
def body(carry, aelems):
|
||||
# carry: running dot-product of the two arrays
|
||||
# aelems: a pair with corresponding elements from the two arrays
|
||||
ae1, ae2 = aelems
|
||||
return (carry + ae1 * ae2 + extra, carry)
|
||||
|
||||
return lax.scan(body, 0., (arr, ones))
|
||||
|
||||
print(api.make_jaxpr(func11)(onp.ones(16), 5.))
|
||||
{ lambda c ; a b.
|
||||
let d e = scan[ forward=True
|
||||
jaxpr={ lambda ; f a b c.
|
||||
let d = mul b c
|
||||
e = add a d
|
||||
g = add e f
|
||||
in (g, a) }
|
||||
length=16
|
||||
linear=(False, False, False, False)
|
||||
num_carry=1
|
||||
num_consts=1 ] b 0.0 a c
|
||||
in (d, e) }
|
||||
>>> def func11(arr, extra):
|
||||
... ones = jnp.ones(arr.shape) # A constant
|
||||
... def body(carry, aelems):
|
||||
... # carry: running dot-product of the two arrays
|
||||
... # aelems: a pair with corresponding elements from the two arrays
|
||||
... ae1, ae2 = aelems
|
||||
... return (carry + ae1 * ae2 + extra, carry)
|
||||
... return lax.scan(body, 0., (arr, ones))
|
||||
...
|
||||
>>> print(make_jaxpr(func11)(onp.ones(16), 5.))
|
||||
{ lambda c ; a b.
|
||||
let d e = scan[ forward=True
|
||||
jaxpr={ lambda ; f a b c.
|
||||
let d = mul b c
|
||||
e = add a d
|
||||
g = add e f
|
||||
in (g, a) }
|
||||
length=16
|
||||
linear=(False, False, False, False)
|
||||
num_carry=1
|
||||
num_consts=1 ] b 0.0 a c
|
||||
in (d, e) }
|
||||
|
||||
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
|
||||
and two input variables corresponding to the arguments ``arr`` and ``extra``.
|
||||
@ -412,26 +418,28 @@ XLA_call
|
||||
|
||||
The call primitive arises from JIT compilation, and it encapsulates
|
||||
a sub-jaxpr along with parameters that specify the backend and the device on which the
|
||||
computation should run. For example::
|
||||
computation should run. For example
|
||||
|
||||
def func12(arg):
|
||||
@api.jit
|
||||
def inner(x):
|
||||
return x + arg * jnp.ones(1) # Include a constant in the inner function
|
||||
return arg + inner(arg - 2.)
|
||||
|
||||
print(api.make_jaxpr(func12)(1.))
|
||||
{ lambda b ; a.
|
||||
let c = sub a 2.0
|
||||
d = xla_call[ backend=None
|
||||
call_jaxpr={ lambda ; c b a.
|
||||
let d = mul b c
|
||||
e = add a d
|
||||
in e }
|
||||
device=None
|
||||
name=inner ] b a c
|
||||
e = add a d
|
||||
in e }
|
||||
>>> from jax import jit
|
||||
>>>
|
||||
>>> def func12(arg):
|
||||
... @jit
|
||||
... def inner(x):
|
||||
... return x + arg * jnp.ones(1) # Include a constant in the inner function
|
||||
... return arg + inner(arg - 2.)
|
||||
...
|
||||
>>> print(make_jaxpr(func12)(1.))
|
||||
{ lambda b ; a.
|
||||
let c = sub a 2.0
|
||||
d = xla_call[ backend=None
|
||||
call_jaxpr={ lambda ; c b a.
|
||||
let d = mul b c
|
||||
e = add a d
|
||||
in (e,) }
|
||||
device=None
|
||||
name=inner ] b a c
|
||||
e = add a d
|
||||
in (e,) }
|
||||
|
||||
The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
|
||||
the top-level input variable ``a`` refers to the ``arg`` parameter of ``func12``.
|
||||
@ -450,30 +458,32 @@ XLA_pmap
|
||||
|
||||
If you use the :py:func:`jax.pmap` transformation, the function to be
|
||||
mapped is captured using the ``xla_pmap`` primitive. Consider this
|
||||
example::
|
||||
example
|
||||
|
||||
def func13(arr, extra):
|
||||
def inner(x):
|
||||
# use a free variable "extra" and a constant jnp.ones(1)
|
||||
return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
|
||||
return api.pmap(inner, axis_name='rows')(arr)
|
||||
|
||||
print(api.make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
|
||||
{ lambda c ; a b.
|
||||
let d = xla_pmap[ axis_name=rows
|
||||
axis_size=1
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; d b a.
|
||||
let c = add a b
|
||||
e = add c d
|
||||
f = psum[ axis_name=rows ] a
|
||||
g = div e f
|
||||
in g }
|
||||
devices=None
|
||||
global_axis_size=None
|
||||
mapped_invars=(True, False, True)
|
||||
name=inner ] c b a
|
||||
in d }
|
||||
>>> from jax import pmap
|
||||
>>>
|
||||
>>> def func13(arr, extra):
|
||||
... def inner(x):
|
||||
... # use a free variable "extra" and a constant jnp.ones(1)
|
||||
... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
|
||||
... return pmap(inner, axis_name='rows')(arr)
|
||||
...
|
||||
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
|
||||
{ lambda c ; a b.
|
||||
let d = xla_pmap[ axis_name=rows
|
||||
axis_size=1
|
||||
backend=None
|
||||
call_jaxpr={ lambda ; d b a.
|
||||
let c = add a b
|
||||
e = add c d
|
||||
f = psum[ axis_name=rows ] a
|
||||
g = div e f
|
||||
in (g,) }
|
||||
devices=None
|
||||
global_axis_size=None
|
||||
mapped_invars=(True, False, True)
|
||||
name=inner ] c b a
|
||||
in (d,) }
|
||||
|
||||
The top-level constvar ``c`` refers to the ``jnp.ones(1)`` constant.
|
||||
The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)
|
||||
|
Loading…
x
Reference in New Issue
Block a user