DOC: note compilation of funcs in lax control_flow

This commit is contained in:
Jake VanderPlas 2022-05-18 14:58:48 -07:00
parent bb0816227d
commit 9695ebe2f0

View File

@ -182,6 +182,10 @@ def fori_loop(lower, upper, body_fun, init_val):
structure with a fixed structure and arrays with fixed shape and dtype at the
leaves).
.. note::
:py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with
:py:func:`jit`, it's usually unnecessary.
Args:
lower: an integer representing the loop index lower bound (inclusive)
upper: an integer representing the loop index upper bound (exclusive)
@ -267,6 +271,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
``while_loop`` is not reverse-mode differentiable because XLA computations
require static bounds on memory requirements.
.. note::
:py:func:`while_loop` compiles ``cond_fun`` and ``body_fun``, so while it
can be combined with :py:func:`jit`, it's usually unnecessary.
Args:
cond_fun: function of type ``a -> Bool``.
body_fun: function of type ``a -> a``.
@ -1398,6 +1406,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
dtype (or a nested tuple/list/dict container data structure with a fixed
structure and arrays with fixed shape and dtype at the leaves).
.. note::
:py:func:`scan` compiles ``f``, so while it can be combined with
:py:func:`jit`, it's usually unnecessary.
Args:
f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
that ``f`` accepts two arguments where the first is a value of the loop