mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
docstring for jax.stages.Wrapped
This commit is contained in:
parent
0869183107
commit
4505d57a60
@ -497,6 +497,14 @@ class Lowered(Stage):
|
||||
|
||||
|
||||
class Wrapped(Protocol):
|
||||
"""A function ready to be specialized, lowered, and compiled.
|
||||
|
||||
This protocol reflects the output of functions such as
|
||||
``jax.jit``. Calling it results in JIT (just-in-time) lowering,
|
||||
compilation, and execution. It can also be explicitly lowered prior
|
||||
to compilation, and the result compiled prior to execution.
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Executes the wrapped function, lowering and compiling as needed."""
|
||||
raise NotImplementedError
|
||||
|
Loading…
x
Reference in New Issue
Block a user