add 'inline' to jit docstring

This commit is contained in:
Matthew Johnson 2021-05-01 12:32:44 -07:00
parent 3c400a3e58
commit fe297e39ca

View File

@ -263,7 +263,10 @@ def jit(
buffers to reduce the amount of memory needed to perform a computation,
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to.
an error if you try to. By default, no arguments are donated.
inline: Specify whether this function should be inlined into enclosing
jaxprs (rather than being represented as an application of the xla_call
primitive with its own subjaxpr). Default False.
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation.
@ -304,7 +307,7 @@ def _python_jit(
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> F:
"""The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit."""
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)
@ -365,15 +368,13 @@ def _cpp_jit(
donate_argnums: Union[int, Iterable[int]] = (),
inline: bool = False,
) -> F:
"""An implementation of `jit` that tries to do as much as possible in C++.
The goal of this function is to speed up the time it takes to process the
arguments, find the correct C++ executable, start the transfer of arguments
and schedule the computation.
As long as it does not support all features of the Python implementation
the C++ code will fallback to `_python_jit` when it faces some unsupported
feature.
"""
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
# and schedule the computation.
# As long as it does not support all features of the Python implementation
# the C++ code will fallback to `_python_jit` when it faces some unsupported
# feature.
_check_callable(fun)
static_argnums, static_argnames = _infer_argnums_and_argnames(
fun, static_argnums, static_argnames)