mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
DOC: note compilation of funcs in lax control_flow
This commit is contained in:
parent
bb0816227d
commit
9695ebe2f0
@ -182,6 +182,10 @@ def fori_loop(lower, upper, body_fun, init_val):
|
||||
structure with a fixed structure and arrays with fixed shape and dtype at the
|
||||
leaves).
|
||||
|
||||
.. note::
|
||||
:py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with
|
||||
:py:func:`jit`, it's usually unnecessary.
|
||||
|
||||
Args:
|
||||
lower: an integer representing the loop index lower bound (inclusive)
|
||||
upper: an integer representing the loop index upper bound (exclusive)
|
||||
@ -267,6 +271,10 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
|
||||
``while_loop`` is not reverse-mode differentiable because XLA computations
|
||||
require static bounds on memory requirements.
|
||||
|
||||
.. note::
|
||||
:py:func:`while_loop` compiles ``cond_fun`` and ``body_fun``, so while it
|
||||
can be combined with :py:func:`jit`, it's usually unnecessary.
|
||||
|
||||
Args:
|
||||
cond_fun: function of type ``a -> Bool``.
|
||||
body_fun: function of type ``a -> a``.
|
||||
@ -1398,6 +1406,10 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
dtype (or a nested tuple/list/dict container data structure with a fixed
|
||||
structure and arrays with fixed shape and dtype at the leaves).
|
||||
|
||||
.. note::
|
||||
:py:func:`scan` compiles ``f``, so while it can be combined with
|
||||
:py:func:`jit`, it's usually unnecessary.
|
||||
|
||||
Args:
|
||||
f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
|
||||
that ``f`` accepts two arguments where the first is a value of the loop
|
||||
|
Loading…
x
Reference in New Issue
Block a user