Clarify tracking

Clarify tracing a bit and use wording that does not suggest that JAX executed python program.
This commit is contained in:
Łukasz Lew 2021-01-25 17:16:29 -08:00 committed by GitHub
parent 84f55f87ec
commit 9dccf567ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -104,10 +104,11 @@ inline.
The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``, in
addition to the operand ``e``.
Note that JAX traces through Python-level control-flow and higher-order functions
when it extracts the jaxpr. This means that just because a Python program contains
functions and control-flow, the resulting jaxpr does not have
to contain control-flow or higher-order features.
Note that even though execution of a program that calls into JAX builds a jaxpr,
Python-level control-flow and Python-level functions execute normally.
This means that just because a Python program contains functions and control-flow,
the resulting jaxpr does not have to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
jaxpr as before