DOC: clarify variable names in scan doc

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Jake VanderPlas 2023-03-02 10:25:54 -08:00
parent a002643a4a
commit 853f65fd99

View File

@ -126,8 +126,9 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
represents the type with the same pytree structure and corresponding leaves
each with an additional leading axis.
When ``a`` is an array type or None, and ``b`` is an array type, the semantics
of :func:`~scan` are given roughly by this Python implementation::
When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
given roughly by this Python implementation::
def scan(f, init, xs, length=None):
if xs is None:
@ -139,9 +140,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
ys.append(y)
return carry, np.stack(ys)
Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
types, and so multiple arrays can be scanned over at once and produce multiple
output arrays. (None is actually an empty pytree.)
Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
values, and so multiple arrays can be scanned over at once and produce multiple
output arrays. ``None`` is actually a special case of this, as it represents an
empty pytree.
Also unlike that Python version, :func:`~scan` is a JAX primitive and is
lowered to a single WhileOp. That makes it useful for reducing