mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Clarify tracking
Clarify tracing a bit and use wording that does not suggest that JAX executed python program.
This commit is contained in:
parent
84f55f87ec
commit
9dccf567ce
@ -104,10 +104,11 @@ inline.
|
||||
The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``, in
|
||||
addition to the operand ``e``.
|
||||
|
||||
Note that JAX traces through Python-level control-flow and higher-order functions
|
||||
when it extracts the jaxpr. This means that just because a Python program contains
|
||||
functions and control-flow, the resulting jaxpr does not have
|
||||
to contain control-flow or higher-order features.
|
||||
Note that even though execution of a program that calls into JAX builds a jaxpr,
|
||||
Python-level control-flow and Python-level functions execute normally.
|
||||
This means that just because a Python program contains functions and control-flow,
|
||||
the resulting jaxpr does not have to contain control-flow or higher-order features.
|
||||
|
||||
For example, when tracing the function ``func3`` JAX will inline the call to
|
||||
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
|
||||
jaxpr as before
|
||||
|
Loading…
x
Reference in New Issue
Block a user