Merge pull request #19868 from dan-zheng:jaxpr-docs

PiperOrigin-RevId: 628393328
This commit is contained in:
jax authors 2024-04-26 06:59:37 -07:00
commit 00663474ce

View File

@ -34,7 +34,7 @@ There are two related representations in the code for jaxprs,
:py:class:`jax.core.Jaxpr`, and is what you obtain when you use
:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields:
* ``jaxpr``: is a :py:class:`jax.core.Jaxpr` representing the actual
* ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual
computation content of the function (described below).
* ``consts`` is a list of constants.
@ -42,9 +42,9 @@ The most interesting part of the ClosedJaxpr is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr` as printed using the following
grammar::
jaxpr ::= { lambda Var* ; Var+.
let Eqn*
in [Expr+] }
Jaxpr ::= { lambda Var* ; Var+. let
Eqn*
in [Expr+] }
where:
* The parameters of the jaxpr are shown as two lists of variables separated by
@ -62,7 +62,7 @@ where:
Equations are printed as follows::
Eqn ::= let Var+ = Primitive [ Param* ] Expr+
Eqn ::= Var+ = Primitive [ Param* ] Expr+
where:
* ``Var+`` are one or more intermediate variables to be defined as the output
@ -76,7 +76,7 @@ where:
square brackets. Each parameter is shown as ``Name = Value``.
Most jaxpr primitives are first-order (they take just one or more Expr as arguments)::
Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments)::
Primitive := add | sub | sin | mul | ...