diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index b99dddac5..c70c06af6 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -126,8 +126,9 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], represents the type with the same pytree structure and corresponding leaves each with an additional leading axis. - When ``a`` is an array type or None, and ``b`` is an array type, the semantics - of :func:`~scan` are given roughly by this Python implementation:: + When the type of ``xs`` (denoted `a` above) is an array type or None, and the type + of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are + given roughly by this Python implementation:: def scan(f, init, xs, length=None): if xs is None: @@ -139,9 +140,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], ys.append(y) return carry, np.stack(ys) - Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree - types, and so multiple arrays can be scanned over at once and produce multiple - output arrays. (None is actually an empty pytree.) + Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree + values, and so multiple arrays can be scanned over at once and produce multiple + output arrays. ``None`` is actually a special case of this, as it represents an + empty pytree. Also unlike that Python version, :func:`~scan` is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing