docstring for jax.stages.Wrapped

This commit is contained in:
Roy Frostig 2022-09-01 18:30:33 -07:00
parent 0869183107
commit 4505d57a60

View File

@ -497,6 +497,14 @@ class Lowered(Stage):
class Wrapped(Protocol):
"""A function ready to be specialized, lowered, and compiled.
This protocol reflects the output of functions such as
``jax.jit``. Calling it results in JIT (just-in-time) lowering,
compilation, and execution. It can also be explicitly lowered prior
to compilation, and the result compiled prior to execution.
"""
def __call__(self, *args, **kwargs):
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError