mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
f0ab673d9e
commit
f8dc650b2a
@ -385,29 +385,39 @@ For the example consider the function ``func11`` below::
|
||||
num_carry=1
|
||||
num_consts=1 ] b 0.0 a * c
|
||||
in (d, e) }
|
||||
{ lambda c ; a b.
|
||||
let d e = scan[ forward=True
|
||||
jaxpr={ lambda ; f a b c.
|
||||
let d = mul b c
|
||||
e = add a d
|
||||
g = add e f
|
||||
in (g, a) }
|
||||
length=16
|
||||
linear=(False, False, False, False)
|
||||
num_carry=1
|
||||
num_consts=1 ] b 0.0 a c
|
||||
in (d, e) }
|
||||
|
||||
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
|
||||
and two input variables corresponding to the arguments ``arr`` and ``extra``.
|
||||
The body of the scan has 5 input variables, of which:
|
||||
The body of the scan has 4 input variables, of which:
|
||||
|
||||
* one (``a``) is a constant (since ``num_consts = 1``), and stands for the
|
||||
* one (``f``) is a constant (since ``num_consts = 1``), and stands for the
|
||||
captured variable ``extra`` used in the loop body,
|
||||
* one (``b``) is the value of the carry (since ``num_carry = 1``)
|
||||
* The remaining 3 are the input values. Notice that only ``c`` and ``e`` are used,
|
||||
and stand respectively for the array element from the first array passed to
|
||||
lax.scan (``arr``) and to the second array (``ones``). The input variables
|
||||
(``d``) seems to be an artifact of the translation.
|
||||
* one (``a``) is the value of the carry (since ``num_carry = 1``)
|
||||
* The remaining 2 are the input values. ``b`` is the array element from the
|
||||
first array passed to lax.scan (``arr``) and ``c`` is the second array
|
||||
(``ones``).
|
||||
|
||||
The ``linear`` parameter describes for each of the input variables whether they
|
||||
are guaranteed to be used linearly in the body. Here, only the unused input
|
||||
variable is marked linear. Once the scan goes through linearization, more arguments
|
||||
will be linear.
|
||||
are guaranteed to be used linearly in the body. Once the scan goes through
|
||||
linearization, more arguments will be linear.
|
||||
|
||||
The scan primitive takes 5 arguments: ``b 0.0 a * c``, of which:
|
||||
The scan primitive takes 4 arguments: ``b 0.0 a c``, of which:
|
||||
|
||||
* one is the free variable for the body
|
||||
* one is the initial value of the carry
|
||||
* The next 3 are the arrays over which the scan operates. The middle one is not used (*).
|
||||
* The next 2 are the arrays over which the scan operates.
|
||||
|
||||
XLA_call
|
||||
^^^^^^^^
|
||||
@ -490,4 +500,3 @@ value of this parameter is a Jaxpr with 3 input variables:
|
||||
The parameter ``mapped_invars`` specify which of the input variables should be
|
||||
mapped and which should be broadcast. In our example, the value of ``extra``
|
||||
is broadcast, the other input values are mapped.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user