mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 06:46:07 +00:00
docstrings and API reference doc listing for the traced AOT stage
This commit is contained in:
parent
914adaf60c
commit
8720a9c0cd
@ -12,6 +12,9 @@ Classes
|
|||||||
:members: trace, lower
|
:members: trace, lower
|
||||||
:special-members: __call__
|
:special-members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: Traced
|
||||||
|
:members: jaxpr, out_info, lower
|
||||||
|
|
||||||
.. autoclass:: Lowered
|
.. autoclass:: Lowered
|
||||||
:members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis
|
:members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis
|
||||||
|
|
||||||
|
@ -733,6 +733,12 @@ class Lowered(Stage):
|
|||||||
|
|
||||||
|
|
||||||
class Traced(Stage):
|
class Traced(Stage):
|
||||||
|
"""Traced form of a function specialized to argument types and values.
|
||||||
|
|
||||||
|
A traced computation is ready for lowering. This class carries the
|
||||||
|
traced representation with the remaining information needed to later
|
||||||
|
lower, compile, and execute it.
|
||||||
|
"""
|
||||||
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
|
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
|
||||||
"_args_flat", "_arg_names", "_num_consts"]
|
"_args_flat", "_arg_names", "_num_consts"]
|
||||||
|
|
||||||
@ -756,6 +762,7 @@ class Traced(Stage):
|
|||||||
|
|
||||||
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
|
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
|
||||||
_private_parameters: mlir.LoweringParameters | None = None):
|
_private_parameters: mlir.LoweringParameters | None = None):
|
||||||
|
"""Lower to compiler input, returning a ``Lowered`` instance."""
|
||||||
from jax._src.interpreters import pxla
|
from jax._src.interpreters import pxla
|
||||||
from jax._src import pjit
|
from jax._src import pjit
|
||||||
|
|
||||||
@ -805,6 +812,8 @@ class Wrapped(Protocol):
|
|||||||
def lower(self, *args, **kwargs) -> Lowered:
|
def lower(self, *args, **kwargs) -> Lowered:
|
||||||
"""Lower this function explicitly for the given arguments.
|
"""Lower this function explicitly for the given arguments.
|
||||||
|
|
||||||
|
This is a shortcut for ``self.trace(*args, **kwargs).lower()``.
|
||||||
|
|
||||||
A lowered function is staged out of Python and translated to a
|
A lowered function is staged out of Python and translated to a
|
||||||
compiler's input language, possibly in a backend-dependent
|
compiler's input language, possibly in a backend-dependent
|
||||||
manner. It is ready for compilation but not yet compiled.
|
manner. It is ready for compilation but not yet compiled.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user