mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
docstrings and API reference doc listing for the traced AOT stage
This commit is contained in:
parent
914adaf60c
commit
8720a9c0cd
@ -12,6 +12,9 @@ Classes
|
||||
:members: trace, lower
|
||||
:special-members: __call__
|
||||
|
||||
.. autoclass:: Traced
|
||||
:members: jaxpr, out_info, lower
|
||||
|
||||
.. autoclass:: Lowered
|
||||
:members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis
|
||||
|
||||
|
@ -733,6 +733,12 @@ class Lowered(Stage):
|
||||
|
||||
|
||||
class Traced(Stage):
|
||||
"""Traced form of a function specialized to argument types and values.
|
||||
|
||||
A traced computation is ready for lowering. This class carries the
|
||||
traced representation with the remaining information needed to later
|
||||
lower, compile, and execute it.
|
||||
"""
|
||||
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
|
||||
"_args_flat", "_arg_names", "_num_consts"]
|
||||
|
||||
@ -756,6 +762,7 @@ class Traced(Stage):
|
||||
|
||||
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
|
||||
_private_parameters: mlir.LoweringParameters | None = None):
|
||||
"""Lower to compiler input, returning a ``Lowered`` instance."""
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src import pjit
|
||||
|
||||
@ -805,6 +812,8 @@ class Wrapped(Protocol):
|
||||
def lower(self, *args, **kwargs) -> Lowered:
|
||||
"""Lower this function explicitly for the given arguments.
|
||||
|
||||
This is a shortcut for ``self.trace(*args, **kwargs).lower()``.
|
||||
|
||||
A lowered function is staged out of Python and translated to a
|
||||
compiler's input language, possibly in a backend-dependent
|
||||
manner. It is ready for compilation but not yet compiled.
|
||||
|
Loading…
x
Reference in New Issue
Block a user