mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: clarify variable names in scan doc
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
a002643a4a
commit
853f65fd99
@ -126,8 +126,9 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
represents the type with the same pytree structure and corresponding leaves
|
||||
each with an additional leading axis.
|
||||
|
||||
When ``a`` is an array type or None, and ``b`` is an array type, the semantics
|
||||
of :func:`~scan` are given roughly by this Python implementation::
|
||||
When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
|
||||
of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
|
||||
given roughly by this Python implementation::
|
||||
|
||||
def scan(f, init, xs, length=None):
|
||||
if xs is None:
|
||||
@ -139,9 +140,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
ys.append(y)
|
||||
return carry, np.stack(ys)
|
||||
|
||||
Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
|
||||
types, and so multiple arrays can be scanned over at once and produce multiple
|
||||
output arrays. (None is actually an empty pytree.)
|
||||
Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
|
||||
values, and so multiple arrays can be scanned over at once and produce multiple
|
||||
output arrays. ``None`` is actually a special case of this, as it represents an
|
||||
empty pytree.
|
||||
|
||||
Also unlike that Python version, :func:`~scan` is a JAX primitive and is
|
||||
lowered to a single WhileOp. That makes it useful for reducing
|
||||
|
Loading…
x
Reference in New Issue
Block a user