Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now updated with the latest changes in Jaxpr.
parent 9e6fe64a66
commit a5c3468c93
@@ -34,7 +34,7 @@ sys.path.insert(0, os.path.abspath('..'))

# -- Project information -----------------------------------------------------

project = 'JAX'
copyright = '2019, Google LLC. NumPy and SciPy documentation are copyright the respective authors.'
copyright = '2020, Google LLC. NumPy and SciPy documentation are copyright the respective authors.'
author = 'The JAX authors'

# The short X.Y version
@@ -34,6 +34,7 @@ For an introduction to JAX, start at the
   :maxdepth: 1
   :caption: Notes

   jaxpr
   async_dispatch
   concurrency
   gpu_memory_allocation
@@ -46,6 +47,7 @@ For an introduction to JAX, start at the
   :caption: Developer documentation

   developer
   jax_internal_api

.. toctree::
   :maxdepth: 3
docs/jax.dlpack.rst (new file, 6 lines)
@@ -0,0 +1,6 @@
jax.dlpack module
=================

.. automodule:: jax.dlpack
    :members:
    :show-inheritance:
@@ -1,7 +1,7 @@
.. currentmodule:: jax

jax package
===========
Public API: jax package
=======================

Subpackages
-----------
docs/jax_internal_api.rst (new file, 14 lines)
@@ -0,0 +1,14 @@
Internal APIs
=============

core
----

.. currentmodule:: jax.core
.. automodule:: jax.core

.. autosummary::
  :toctree: _autosummary

  Jaxpr
  TypedJaxpr
docs/jaxpr.rst (new file, 491 lines)
@@ -0,0 +1,491 @@
Understanding JAXPR
===================

(Note: the code examples in this file can also be found in
``jax/tests/api_test::JaxprTest.testExamplesJaxprDoc``.)

Conceptually, one can think of JAX transformations as first tracing the Python
function to be transformed into a small and well-behaved intermediate form,
the JAXPR, that is then transformed accordingly, and ultimately compiled and executed.
One of the reasons JAX can pack so much power into such a small software package
is that it starts with a familiar and flexible programming interface (Python with NumPy)
and it uses the actual Python interpreter to do most of the heavy lifting to distill the
essence of the computation into a simple statically-typed expression language
with limited higher-order features: the JAXPR language.

Not all Python programs can be processed this way, but it turns out that many
scientific computing and machine learning programs do have this property.

Before we proceed, it is important to point out that not all JAX transformations
materialize a JAXPR as described above; some, e.g., differentiation,
apply transformations incrementally during tracing.
Nevertheless, if one wants to understand how JAX works internally, or to
make use of the result of JAX tracing, it is useful to understand JAXPR.

A JAXPR instance represents a function with one or more typed parameters (input variables)
and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes.
The inputs and outputs have types, which in JAX are represented as abstract
values. There are two related representations in the code for JAXPRs. The main
one is :py:class:`jax.core.TypedJaxpr`; it is what you obtain when you
use :py:func:`jax.make_jaxpr` to inspect JAXPRs. It has the following
fields:

* ``jaxpr`` is the actual computation content of the function (described below).
* ``literals`` is a list of constants. For various reasons, during tracing JAX
  collects the non-scalar constants that arise, e.g., constants that appear in
  the Python program or the result of constant-folding such constants, and
  replaces them with variables. The variables that stand for these constants
  are listed separately in the enclosed ``jaxpr``.
  When applying a ``TypedJaxpr`` to some actual
  arguments, one must pass first the ``literals`` followed by the actual arguments.
* ``in_avals`` and ``out_avals`` are the types of the input variables
  (excluding the ones that correspond to the ``literals``), and of the output values.
  In JAX these types are called abstract values, e.g., ``ShapedArray(float32[10,10])``.
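As a quick way to see these fields in practice, you can inspect the object
returned by :py:func:`jax.make_jaxpr` directly. The sketch below is
illustrative only (the small helper ``examine`` is not part of JAX, and the
exact printed form of the abstract values may vary between versions)::

  import jax
  from jax import numpy as jnp

  def examine(first, second):
    return jnp.sum(first + jnp.sin(second))

  typed_jaxpr = jax.make_jaxpr(examine)(jnp.zeros(8), jnp.ones(8))
  print(typed_jaxpr.jaxpr)      # the computation content, printed with the grammar below
  print(typed_jaxpr.literals)   # constants hoisted during tracing (empty for this function)
  print(typed_jaxpr.in_avals)   # e.g., [ShapedArray(float32[8]), ShapedArray(float32[8])]
  print(typed_jaxpr.out_avals)  # e.g., [ShapedArray(float32[])]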
The most interesting part of the TypedJaxpr is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr`, printed using the following
grammar::

  JAXPR ::= { lambda Var* ; Var+.
              let Eqn*
              in [Expr+] }

where:

* The parameters of the JAXPR are shown as two lists of variables separated by
  ``;``. The first list contains the variables that have been introduced
  to stand for constants that have been hoisted out. These are called the
  `constvars`. The second list contains the real input variables.
* ``Eqn*`` is a list of equations, defining intermediate variables referring to
  intermediate expressions. Each equation defines one or more variables as the
  result of applying a primitive to some atomic expressions. Each equation uses only
  input variables and intermediate variables defined by previous equations.
* ``Expr+`` is a list of output atomic expressions for the JAXPR.

Equations are printed as follows::

  Eqn ::= let Var+ = Primitive [ Param* ] Expr+

where:

* ``Var+`` are one or more intermediate variables to be defined as the
  output of a primitive invocation (some primitives can return multiple values).
* ``Expr+`` are one or more atomic expressions, each either a variable or a
  literal constant. A special form of atomic expression is the `unit`
  expression, printed as ``*`` and standing for a value that is not needed
  in the rest of the computation and has been elided.
* ``Param*`` are zero or more named parameters to the primitive, printed in
  square brackets. Each parameter is shown as ``Name = Value``.


Most JAXPR primitives are first-order (they take just one or more ``Expr`` as arguments)::

  Primitive := add | sub | sin | mul | ...


The JAXPR primitives are documented in the :py:mod:`jax.lax` module.
For example, here is the JAXPR produced for the function ``func1`` below::

  import jax
  from jax import numpy as jnp

  def func1(first, second):
    temp = first + jnp.sin(second) * 3.
    return jnp.sum(temp)

  print(jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
  { lambda ; a b.
    let c = sin b
        d = mul c 3.0
        e = add a d
        f = reduce_sum[ axes=(0,)
                        input_shape=(8,) ] e
    in f }

Here there are no constvars; ``a`` and ``b`` are the input variables,
corresponding respectively to the ``first`` and ``second`` function parameters.
The scalar literal ``3.0`` is kept inline.
The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``, in
addition to the operand ``e``.

Note that JAX traces through Python-level control-flow and higher-order functions
when it extracts the JAXPR. This means that even though a Python program contains
functions and control-flow, the resulting JAXPR does not have
to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
JAXPR as before::
  def func2(inner, first, second):
    temp = first + inner(second) * 3.
    return jnp.sum(temp)

  def inner(second):
    if second.shape[0] > 4:
      return jnp.sin(second)
    else:
      assert False

  def func3(first, second):
    return func2(inner, first, second)

  print(jax.make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
  { lambda ; a b.
    let c = sin b
        d = mul c 3.0
        e = add a d
        f = reduce_sum[ axes=(0,)
                        input_shape=(8,) ] e
    in f }
Handling PyTrees
----------------

In JAXPR there are no tuple types; instead primitives take multiple inputs
and produce multiple outputs. When processing a function that has structured
inputs or outputs, JAX will flatten those, and in the JAXPR they will appear as lists
of inputs and outputs. For more details, please see the documentation for
PyTrees (:doc:`notebooks/JAX_pytrees`).
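The flattening itself is done by the pytree machinery in :py:mod:`jax.tree_util`.
The following sketch (the nested structure is an arbitrary illustration) shows how
a structured argument becomes a flat list of leaves together with a tree definition
that can rebuild the structure::

  from jax import tree_util
  from jax import numpy as jnp

  args = (jnp.zeros(8), {"offset": jnp.ones(8), "scale": 3.0})
  leaves, treedef = tree_util.tree_flatten(args)
  print(len(leaves))    # 3: one leaf per array/scalar in the structure
  print(treedef)        # a description of the original nesting
  restored = tree_util.tree_unflatten(treedef, leaves)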
For example, the following code produces an identical JAXPR to what we saw
before (with two input vars, one for each element of the input tuple)::

  def func4(arg):  # arg is a pair
    temp = arg[0] + jnp.sin(arg[1]) * 3.
    return jnp.sum(temp)

  print(jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
  { lambda ; a b.
    let c = sin b
        d = mul c 3.0
        e = add a d
        f = reduce_sum[ axes=(0,)
                        input_shape=(8,) ] e
    in f }
Constant Vars
-------------

ConstVars arise when the computation contains array constants, either
from the Python program, or from constant-folding. For example, consider the function
``func6`` below::

  def func5(first, second):
    temp = first + jnp.sin(second) * 3. - jnp.ones(8)
    return temp

  def func6(first):
    return func5(first, jnp.ones(8))

  print(jax.make_jaxpr(func6)(jnp.ones(8)))


JAX produces the following JAXPR::

  { lambda b d ; a.
    let c = add a b
        e = sub c d
    in e }

When tracing ``func6``, the function ``func5`` is invoked with a constant value
(``jnp.ones(8)``) for the second argument. As a result, the sub-expression
``jnp.sin(second) * 3.`` is constant-folded.
There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
printed JAXPR alone which constants the constant variables stand for.
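One way to recover them is to look at the ``literals`` field of the
:py:class:`jax.core.TypedJaxpr` returned by ``make_jaxpr``; as noted earlier, those
constants are what gets passed for the constvars, in order, when the jaxpr is
applied. A minimal sketch (assuming the ``TypedJaxpr`` form described above)::

  typed_jaxpr = jax.make_jaxpr(func6)(jnp.ones(8))
  print(typed_jaxpr.jaxpr.constvars)   # e.g., [b, d]
  print(typed_jaxpr.literals)          # the constant arrays that b and d stand for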
Higher-order primitives
-----------------------

JAXPR includes several higher-order primitives. They are more complicated because
they include sub-JAXPRs.

Cond
^^^^

JAX traces through normal Python conditionals. To capture a conditional expression
for dynamic execution, one must use the :py:func:`jax.lax.cond` constructor
with the following signature::

  lax.cond(pred : bool, true_op: A, true_body: A -> B, false_op: C, false_body: C -> B) -> B
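The semantics can be sketched in pure Python as follows (a reference sketch only,
not the actual implementation: ``lax.cond`` stages *both* branches into the jaxpr so
that the choice can be made at execution time)::

  def cond_reference(pred, true_op, true_body, false_op, false_body):
    # Like Python's `if`: evaluate one branch on its operand.
    if pred:
      return true_body(true_op)
    else:
      return false_body(false_op)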
For example::

  from jax import lax

  def func7(arg):
    return lax.cond(arg >= 0.,
                    arg,
                    lambda xtrue: xtrue + 3.,
                    arg,
                    lambda xfalse: xfalse - 3.)

  print(jax.make_jaxpr(func7)(5.))
  { lambda ; a.
    let b = ge a 0.0
        c = cond[ false_jaxpr={ lambda ; a.
                                let b = sub a 3.0
                                in b }
                  linear=(False, False)
                  true_jaxpr={ lambda ; a.
                               let b = add a 3.0
                               in b } ] b a a
    in c }


The cond primitive has a number of parameters:

* `true_jaxpr` and `false_jaxpr` are JAXPRs that correspond to the true
  and false branch functionals. In this example, each of those functionals takes
  one input variable, corresponding to ``xtrue`` and ``xfalse`` respectively.
* `linear` is a tuple of booleans that is used internally by the auto-differentiation
  machinery to encode which of the input parameters are used linearly in the
  conditional.

The above instance of the cond primitive takes three operands.
The first one (``b``) is the predicate, then ``a`` is the ``true_op`` (``arg``, to be
passed to ``true_jaxpr``), and finally ``a`` again is the ``false_op``
(``arg``, to be passed to ``false_jaxpr``).

The following example shows a more complicated situation in which the input
to the branch functionals is a tuple, and the `false` branch functional
contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`::

  def func8(arg1, arg2):  # arg2 is a pair
    return lax.cond(arg1 >= 0.,
                    arg2,
                    lambda xtrue: xtrue[0],
                    arg2,
                    lambda xfalse: jnp.ones(1) + xfalse[1])

  print(jax.make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
  { lambda e ; a b c.
    let d = ge a 0.0
        f = cond[ false_jaxpr={ lambda ; c a b.
                                let d = add c b
                                in d }
                  linear=(False, False, False, False, False)
                  true_jaxpr={ lambda ; a b.
                               let
                               in a } ] d b c e b c
    in f }

The top-level JAXPR has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
body of the ``false_jaxpr``) and three input variables ``a b c`` (corresponding to ``arg1``
and the two elements of ``arg2``; note that ``arg2`` has been flattened).
The ``true_jaxpr`` has two input variables (corresponding to the two elements of ``arg2``
that are passed to ``true_jaxpr``).
The ``false_jaxpr`` has three input variables (``c`` corresponding to the constant
``jnp.ones(1)``, and ``a b`` for the two elements of ``arg2`` that are passed
to ``false_jaxpr``).

The actual operands to the cond primitive are ``d b c e b c``, which correspond in order to:

* 1 operand for the predicate,
* 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are input vars
  corresponding to ``arg2`` for the top-level JAXPR,
* 1 constant for ``false_jaxpr``, i.e., ``e``, which is a constvar for the top-level JAXPR,
* 2 operands for ``false_jaxpr``, i.e., ``b`` and ``c``, which are the input vars
  corresponding to ``arg2`` for the top-level JAXPR.
While
^^^^^

Just like for conditionals, Python loops are inlined during tracing.
If you want to capture a loop for dynamic execution, you must use one of the
special operations, :py:func:`jax.lax.while_loop` (a primitive)
or :py:func:`jax.lax.fori_loop`
(a helper that generates a while_loop primitive)::

  lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
  lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C


In the above signatures, ``C`` stands for the type of the loop "carry" value.
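Their behavior can be sketched in pure Python as follows (a reference sketch of the
semantics only; the real operations stage ``cond_fun`` and ``body_fun`` into jaxprs
rather than running Python loops)::

  def while_loop_reference(cond_fun, body_fun, init):
    carry = init
    while cond_fun(carry):
      carry = body_fun(carry)
    return carry

  def fori_loop_reference(start, end, body, init):
    # fori_loop has the semantics of this Python loop; it is expressed with a
    # while_loop whose carry also holds the loop index and the end bound, as
    # visible in the jaxpr below.
    carry = init
    for i in range(start, end):
      carry = body(i, carry)
    return carry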
For example, here is a ``fori`` loop::

  import numpy as onp   # classic NumPy, used here for the input array

  def func10(arg, n):
    ones = jnp.ones(arg.shape)  # A constant
    return lax.fori_loop(0, n,
                         lambda i, carry: carry + ones * 3. + arg,
                         arg + ones)

  print(jax.make_jaxpr(func10)(onp.ones(16), 5))
  { lambda c d ; a b.
    let e = add a d
        f g h = while[ body_jaxpr={ lambda ; e g a b c.
                                    let d = add a 1
                                        f = add c e
                                        h = add f g
                                    in (d, b, h) }
                       body_nconsts=2
                       cond_jaxpr={ lambda ; a b c.
                                    let d = lt a b
                                    in d }
                       cond_nconsts=0 ] c a 0 b e
    in h }

The top-level JAXPR has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
There are also two input variables (``a`` corresponding to ``arg`` and ``b`` corresponding
to ``n``).
The loop carry consists of three values, as seen in the body of ``cond_jaxpr``
(corresponding to the iteration index, the iteration end, and the accumulated value carry).
Note that ``body_jaxpr`` takes 5 input variables. The first two are actually
constvars: ``e`` corresponding to ``ones * 3.`` and ``g`` corresponding to the
captured use of ``arg`` in the loop body.
The parameter ``body_nconsts = 2`` specifies that there are 2 constants for the
``body_jaxpr``.
The other 3 input variables for ``body_jaxpr`` correspond to the flattened carry values.

The while primitive takes 5 arguments: ``c a 0 b e``, as follows:

* 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0),
* 2 constants for ``body_jaxpr`` (``c`` and ``a``),
* 3 parameters for the initial value of the carry.
Scan
^^^^

JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there is a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are constructed
with the :py:func:`jax.lax.scan` operator::

  lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

Here ``C`` is the type of the scan carry, ``A`` is the element type of the input array(s),
and ``B`` is the element type of the output array(s).
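Its semantics can be sketched in pure Python (a reference sketch only; ``lax.scan``
records ``body_fun`` as a jaxpr and stacks the per-iteration outputs into an array).
When several arrays are scanned over, as in ``func11`` below, corresponding elements
are passed to the body together::

  def scan_reference(body_fun, init_carry, in_arr):
    carry = init_carry
    ys = []
    for x in in_arr:                 # loop over the leading axis
      carry, y = body_fun(carry, x)
      ys.append(y)
    return carry, jnp.stack(ys)      # final carry C and stacked outputs Array[B]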
For example, consider the function ``func11`` below::

  def func11(arr, extra):
    ones = jnp.ones(arr.shape)  # A constant

    def body(carry, aelems):
      # carry: running dot-product of the two arrays
      # aelems: a pair with corresponding elements from the two arrays
      ae1, ae2 = aelems
      return (carry + ae1 * ae2 + extra, carry)

    return lax.scan(body, 0., (arr, ones))

  print(jax.make_jaxpr(func11)(onp.ones(16), 5.))
  { lambda c ; a b.
    let d e = scan[ forward=True
                    jaxpr={ lambda ; a b c d e.
                            let f = mul c e
                                g = add b f
                                h = add g a
                            in (h, b) }
                    length=16
                    linear=(False, False, False, True, False)
                    num_carry=1
                    num_consts=1 ] b 0.0 a * c
    in (d, e) }

The top-level JAXPR has one constvar ``c`` corresponding to the ``ones`` constant,
and two input variables corresponding to the arguments ``arr`` and ``extra``.
The body of the scan has 5 input variables, of which:

* one (``a``) is a constant (since ``num_consts = 1``), and stands for the
  captured variable ``extra`` used in the loop body,
* one (``b``) is the value of the carry (since ``num_carry = 1``),
* the remaining 3 are the input values. Notice that only ``c`` and ``e`` are used;
  they stand respectively for the element of the first array passed to
  ``lax.scan`` (``arr``) and of the second array (``ones``). The remaining input
  variable (``d``) seems to be an artifact of the translation.

The ``linear`` parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Here, only the unused input
variable is marked linear. Once the scan goes through linearization, more arguments
will be linear.

The scan primitive takes 5 arguments: ``b 0.0 a * c``, of which:

* one is the free variable for the body,
* one is the initial value of the carry,
* the next 3 are the arrays over which the scan operates. The middle one (``*``) is not used.
XLA_call
^^^^^^^^

The call primitive arises from JIT compilation, and it encapsulates
a sub-JAXPR along with parameters that specify the backend and the device on which the
computation should run. For example::

  def func12(arg):
    @jax.jit
    def inner(x):
      return x + arg * jnp.ones(1)  # Include a constant in the inner function
    return arg + inner(arg - 2.)

  print(jax.make_jaxpr(func12)(1.))
  { lambda b ; a.
    let c = sub a 2.0
        d = xla_call[ backend=None
                      call_jaxpr={ lambda ; c b a.
                                   let d = mul b c
                                       e = add a d
                                   in e }
                      device=None
                      name=inner ] b a c
        e = add a d
    in e }

The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
the top-level input variable ``a`` refers to the ``arg`` parameter of ``func12``.
The ``xla_call`` primitive stands for a call to the jitted ``inner`` function.
The primitive has the function body in the ``call_jaxpr`` parameter, a JAXPR
with 3 input parameters:

* ``c`` is a constvar and stands for the ``ones`` constant,
* ``b`` corresponds to the free variable ``arg`` captured in the ``inner`` function,
* ``a`` corresponds to the ``inner`` parameter ``x``.

The primitive takes three arguments ``b a c``.
XLA_pmap
^^^^^^^^

If you use the :py:func:`jax.pmap` transformation, the function to be
mapped is captured using the ``xla_pmap`` primitive. Consider this
example::

  def func13(arr, extra):
    def inner(x):
      # use a free variable "extra" and a constant jnp.ones(1)
      return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
    return jax.pmap(inner, axis_name='rows')(arr)

  print(jax.make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
  { lambda c ; a b.
    let d = xla_pmap[ axis_name=rows
                      axis_size=1
                      backend=None
                      call_jaxpr={ lambda ; d b a.
                                   let c = add a b
                                       e = add c d
                                       f = psum[ axis_name=rows ] a
                                       g = div e f
                                   in g }
                      devices=None
                      global_axis_size=None
                      mapped_invars=(True, False, True)
                      name=inner ] c b a
    in d }

The top-level constvar ``c`` refers to the ``jnp.ones(1)`` constant.
The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)
and the body of the function to be mapped as the ``call_jaxpr`` parameter. The
value of this parameter is a Jaxpr with 3 input variables:

* ``d`` stands for the constant ``jnp.ones(1)``,
* ``b`` stands for the free variable ``extra``,
* ``a`` stands for the parameter ``x`` of ``inner``.

The parameter ``mapped_invars`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast and the other input values are mapped.
@@ -1,7 +0,0 @@
jax
===

.. toctree::
   :maxdepth: 4

   jax
@@ -1154,6 +1154,7 @@ def linearize(fun, *primals):

  In terms of values computed, `linearize` behaves much like a curried `jvp`,
  where these two code blocks compute the same values::

    y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

    y, f_jvp = jax.linearize(f, x)
@@ -1168,8 +1169,10 @@ def linearize(fun, *primals):
  i.e. to evaluate a pushforward for many different input tangent vectors at the
  same linearization point. Moreover if all the input tangent vectors are known
  at once, it can be more efficient to vectorize using `vmap`, as in::

    pushfwd = partial(jvp, f, (x,))
    y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

  By using `vmap` and `jvp` together like this we avoid the stored-linearization
  memory cost that scales with the depth of the computation, which is incurred
  by both `linearize` and `vjp`.
jax/core.py (21 lines changed)
@@ -309,6 +309,7 @@ class Trace(object):

  def pure(self, val):
    """Given a concrete value, makes a Tracer for it."""
    assert False

  def lift(self, tracer):
@@ -317,6 +318,19 @@ class Trace(object):
  def sublift(self, tracer):
    assert False

  def process_primitive(self, primitive, tracers, params):
    """Processes a primitive.

    Args:
      primitive: the primitive
      tracers: the tracers for the arguments
      params: the primitive parameters

    Returns:
      either a tracer, or a list of tracers (if primitive.multiple_results)
    """
    assert False, "Must override"

  def __repr__(self):
    return '{}(level={}/{})'.format(
        self.__class__.__name__, self.level, self.sublevel)
@@ -713,8 +727,13 @@ def pp_eqn(eqn):
          >> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr

def pp_jaxpr(jaxpr):
  if len(jaxpr.outvars) > 1:
    pp_outvars = str(tuple(jaxpr.outvars))
  else:
    pp_outvars = str(jaxpr.outvars[0])

  return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
                                         pp_vars(jaxpr.invars))) +
          ((pp('let ') >>
            vcat(map(pp_eqn, jaxpr.eqns))) +
           pp('in {} }}'.format(jaxpr.outvars))).indent(2))
           pp('in {} }}'.format(pp_outvars))).indent(2))
@@ -174,6 +174,12 @@ def defbroadcasting(prim):
  primitive_batchers[prim] = partial(broadcast_batcher, prim)

def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the arguments
    dims: for each argument, the dimension that is being batched (or None)
  """
  shapes = {(x.shape, d) for x, d in zip(args, dims) if onp.ndim(x)}
  if len(shapes) == 1:
    # if there's only agreeing batch dims and scalars, just call the primitive
@@ -699,7 +699,7 @@ class JaxTestCase(parameterized.TestCase):
    expected_clean = re.sub(ignore_space_re, '\n', expected.strip())
    what_clean = re.sub(ignore_space_re, '\n', what.strip())
    self.assertMultiLineEqual(expected_clean, what_clean,
                              msg="Expecting\n"+expected)
                              msg="Found\n{}\nExpecting\n{}".format(what, expected))

  def _CompileAndCheck(self, fun, args_maker, check_dtypes,
                       rtol=None, atol=None):
@@ -174,6 +174,7 @@ _registry = {
    type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None),
}

def _replace_nones(sentinel, tree):
  """Replaces `None` in `tree` with `sentinel`."""
  if tree is None:
    return sentinel
  else:
@@ -1735,11 +1735,11 @@ class JaxprTest(jtu.JaxTestCase):
      return (x, 1., np.zeros(1))

    jaxpr = api.make_jaxpr(fun)(0.)
    self.assertMultiLineStrippedEqual(str(jaxpr), """
    { lambda b ; a.
      let
      in [a, 1.0, b] }
    """)
    self.assertMultiLineStrippedEqual("""
    { lambda b ; a.
      let
      in (a, 1.0, b) }
    """, str(jaxpr))

  def test_cond(self):
    def f(x):
@@ -1749,21 +1749,193 @@ class JaxprTest(jtu.JaxTestCase):
                      x + 2.,
                      lambda xf: xf - x)
    jaxpr = api.make_jaxpr(f)(3.)
    self.assertMultiLineStrippedEqual(str(jaxpr), """
    { lambda ; a.
      let b = ge a 0.0
          c = add a 1.0
          d = add a 2.0
          e = cond[ false_jaxpr={ lambda ; b a.
                                  let c = sub a b
                                  in [c] }
                    linear=(False, False, False, False)
                    true_jaxpr={ lambda ; b a.
                                 let c = add a b
                                 in [c] } ] b a c a d
      in [e] }
    """)
    self.assertMultiLineStrippedEqual("""
    { lambda ; a.
      let b = ge a 0.0
          c = add a 1.0
          d = add a 2.0
          e = cond[ false_jaxpr={ lambda ; b a.
                                  let c = sub a b
                                  in c }
                    linear=(False, False, False, False)
                    true_jaxpr={ lambda ; b a.
                                 let c = add a b
                                 in c } ] b a c a d
      in e }
    """, str(jaxpr))

  def testExamplesJaxprDoc(self):
    """Tests examples included in the Understanding JAXPRs doc (docs/jaxpr.rst)."""
    from jax import numpy as jnp

    def func1(first, second):
      temp = first + jnp.sin(second) * 3.
      return jnp.sum(temp)

    jaxpr = jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))
    self.assertMultiLineStrippedEqual("""
    { lambda ; a b.
      let c = sin b
          d = mul c 3.0
          e = add a d
          f = reduce_sum[ axes=(0,)
                          input_shape=(8,) ] e
      in f }
    """, str(jaxpr))

    def func5(first, second):
      temp = first + np.sin(second) * 3. - jnp.ones(8)
      return temp

    def func6(first):
      return func5(first, jnp.ones(8))

    jaxpr = api.make_jaxpr(func6)(jnp.ones(8))
    self.assertMultiLineStrippedEqual("""
    { lambda b d ; a.
      let c = add a b
          e = sub c d
      in e }
    """, str(jaxpr))

    def func7(arg):
      return lax.cond(arg >= 0.,
                      arg,
                      lambda xtrue: xtrue + 3.,
                      arg,
                      lambda xfalse: xfalse - 3.)

    jaxpr = api.make_jaxpr(func7)(5.)
    self.assertMultiLineStrippedEqual("""
    { lambda ; a.
      let b = ge a 0.0
          c = cond[ false_jaxpr={ lambda ; a.
                                  let b = sub a 3.0
                                  in b }
                    linear=(False, False)
                    true_jaxpr={ lambda ; a.
                                 let b = add a 3.0
                                 in b } ] b a a
      in c }
    """, str(jaxpr))

    def func8(arg1, arg2):  # arg2 is a pair
      return lax.cond(arg1 >= 0.,
                      arg2,
                      lambda xtrue: xtrue[0],
                      arg2,
                      lambda xfalse: jnp.ones(1) + xfalse[1])

    jaxpr = api.make_jaxpr(func8)(5., (jnp.zeros(1), 2.))
    self.assertMultiLineStrippedEqual("""
    { lambda e ; a b c.
      let d = ge a 0.0
          f = cond[ false_jaxpr={ lambda ; c a b.
                                  let d = add c b
                                  in d }
                    linear=(False, False, False, False, False)
                    true_jaxpr={ lambda ; a b.
                                 let
                                 in a } ] d b c e b c
      in f }
    """, str(jaxpr))

    def func10(arg, n):
      ones = jnp.ones(arg.shape)  # A constant
      return lax.fori_loop(0, n,
                           lambda i, carry: carry + ones * 3. + arg,
                           arg + ones)

    jaxpr = api.make_jaxpr(func10)(onp.ones(16), 5)
    self.assertMultiLineStrippedEqual("""
    { lambda c d ; a b.
      let e = add a d
          f g h = while[ body_jaxpr={ lambda ; e g a b c.
                                      let d = add a 1
                                          f = add c e
                                          h = add f g
                                      in (d, b, h) }
                         body_nconsts=2
                         cond_jaxpr={ lambda ; a b c.
                                      let d = lt a b
                                      in d }
                         cond_nconsts=0 ] c a 0 b e
      in h }
    """, str(jaxpr))

    def func11(arr, extra):
      ones = jnp.ones(arr.shape)  # A constant

      def body(carry, aelems):
        # carry: running dot-product of the two arrays
        # aelems: a pair with corresponding elements from the two arrays
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)

      return lax.scan(body, 0., (arr, ones))

    jaxpr = api.make_jaxpr(func11)(onp.ones(16), 5.)
    self.assertMultiLineStrippedEqual("""
    { lambda c ; a b.
      let d e = scan[ forward=True
                      jaxpr={ lambda ; a b c d e.
                              let f = mul c e
                                  g = add b f
                                  h = add g a
                              in (h, b) }
                      length=16
                      linear=(False, False, False, True, False)
                      num_carry=1
                      num_consts=1 ] b 0.0 a * c
      in (d, e) }
    """, str(jaxpr))

    def func12(arg):
      @api.jit
      def inner(x):
        return x + arg * jnp.ones(1)  # Include a constant in the inner function

      return arg + inner(arg - 2.)

    jaxpr = api.make_jaxpr(func12)(1.)
    self.assertMultiLineStrippedEqual("""
    { lambda b ; a.
      let c = sub a 2.0
          d = xla_call[ backend=None
                        call_jaxpr={ lambda ; c b a.
                                     let d = mul b c
                                         e = add a d
                                     in e }
                        device=None
                        name=inner ] b a c
          e = add a d
      in e }
    """, str(jaxpr))

    def func13(arr, extra):
      def inner(x):
        # use a free variable "extra" and a constant jnp.ones(1)
        return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')

      return api.pmap(inner, axis_name='rows')(arr)

    jaxpr = api.make_jaxpr(func13)(jnp.ones((1, 3)), 5.)
    self.assertMultiLineStrippedEqual("""
    { lambda c ; a b.
      let d = xla_pmap[ axis_name=rows
                        axis_size=1
                        backend=None
                        call_jaxpr={ lambda ; d b a.
                                     let c = add a b
                                         e = add c d
                                         f = psum[ axis_name=rows ] a
                                         g = div e f
                                     in g }
                        devices=None
                        global_axis_size=None
                        mapped_invars=(True, False, True)
                        name=inner ] c b a
      in d }
    """, str(jaxpr))


class LazyTest(jtu.JaxTestCase):