diff --git a/docs/jax.stages.rst b/docs/jax.stages.rst index 804019ee1..d26def14d 100644 --- a/docs/jax.stages.rst +++ b/docs/jax.stages.rst @@ -12,6 +12,9 @@ Classes :members: trace, lower :special-members: __call__ +.. autoclass:: Traced + :members: jaxpr, out_info, lower + .. autoclass:: Lowered :members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 9d6df18f3..8c2074524 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -733,6 +733,12 @@ class Lowered(Stage): class Traced(Stage): + """Traced form of a function specialized to argument types and values. + + A traced computation is ready for lowering. This class carries the + traced representation with the remaining information needed to later + lower, compile, and execute it. + """ __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable", "_args_flat", "_arg_names", "_num_consts"] @@ -756,6 +762,7 @@ class Traced(Stage): def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, _private_parameters: mlir.LoweringParameters | None = None): + """Lower to compiler input, returning a ``Lowered`` instance.""" from jax._src.interpreters import pxla from jax._src import pjit @@ -805,6 +812,8 @@ class Wrapped(Protocol): def lower(self, *args, **kwargs) -> Lowered: """Lower this function explicitly for the given arguments. + This is a shortcut for ``self.trace(*args, **kwargs).lower()``. + A lowered function is staged out of Python and translated to a compiler's input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.