Update scan jaxpr documentation. (#2641)

Closes #2640.
This commit is contained in:
Skye Wanderman-Milne 2020-04-07 19:03:41 -07:00 committed by GitHub
parent f0ab673d9e
commit f8dc650b2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -385,29 +385,39 @@ For the example consider the function ``func11`` below::
num_carry=1
num_consts=1 ] b 0.0 a * c
in (d, e) }
{ lambda c ; a b.
let d e = scan[ forward=True
jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
in (g, a) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1 ] b 0.0 a c
in (d, e) }
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
and two input variables corresponding to the arguments ``arr`` and ``extra``.
The body of the scan has 5 input variables, of which:
The body of the scan has 4 input variables, of which:
* one (``a``) is a constant (since ``num_consts = 1``), and stands for the
* one (``f``) is a constant (since ``num_consts = 1``), and stands for the
captured variable ``extra`` used in the loop body,
* one (``b``) is the value of the carry (since ``num_carry = 1``)
* The remaining 3 are the input values. Notice that only ``c`` and ``e`` are used,
and stand respectively for the array element from the first array passed to
lax.scan (``arr``) and to the second array (``ones``). The input variables
(``d``) seems to be an artifact of the translation.
* one (``a``) is the value of the carry (since ``num_carry = 1``)
* The remaining 2 are the input values. ``b`` is the array element from the
first array passed to lax.scan (``arr``) and ``c`` is the second array
(``ones``).
The ``linear`` parameter describes for each of the input variables whether they
are guaranteed to be used linearly in the body. Here, only the unused input
variable is marked linear. Once the scan goes through linearization, more arguments
will be linear.
are guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.
The scan primitive takes 5 arguments: ``b 0.0 a * c``, of which:
The scan primitive takes 4 arguments: ``b 0.0 a c``, of which:
* one is the free variable for the body
* one is the initial value of the carry
* The next 3 are the arrays over which the scan operates. The middle one is not used (*).
* The next 2 are the arrays over which the scan operates.
XLA_call
^^^^^^^^
@ -490,4 +500,3 @@ value of this parameter is a Jaxpr with 3 input variables:
The parameter ``mapped_invars`` specify which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast, the other input values are mapped.