docstrings and API reference doc listing for the traced AOT stage

This commit is contained in:
Roy Frostig 2025-02-11 22:28:56 -08:00
parent 914adaf60c
commit 8720a9c0cd
2 changed files with 12 additions and 0 deletions

View File

@ -12,6 +12,9 @@ Classes
:members: trace, lower
:special-members: __call__
.. autoclass:: Traced
:members: jaxpr, out_info, lower
.. autoclass:: Lowered
:members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis

View File

@ -733,6 +733,12 @@ class Lowered(Stage):
class Traced(Stage):
"""Traced form of a function specialized to argument types and values.
A traced computation is ready for lowering. This class carries the
traced representation with the remaining information needed to later
lower, compile, and execute it.
"""
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
"_args_flat", "_arg_names", "_num_consts"]
@ -756,6 +762,7 @@ class Traced(Stage):
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
_private_parameters: mlir.LoweringParameters | None = None):
"""Lower to compiler input, returning a ``Lowered`` instance."""
from jax._src.interpreters import pxla
from jax._src import pjit
@ -805,6 +812,8 @@ class Wrapped(Protocol):
def lower(self, *args, **kwargs) -> Lowered:
"""Lower this function explicitly for the given arguments.
This is a shortcut for ``self.trace(*args, **kwargs).lower()``.
A lowered function is staged out of Python and translated to a
compiler's input language, possibly in a backend-dependent
manner. It is ready for compilation but not yet compiled.