mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add 'inline' to jit docstring
This commit is contained in:
parent
3c400a3e58
commit
fe297e39ca
@ -263,7 +263,10 @@ def jit(
|
||||
buffers to reduce the amount of memory needed to perform a computation,
|
||||
for example recycling one of your input buffers to store a result. You
|
||||
should not reuse buffers that you donate to a computation, JAX will raise
|
||||
an error if you try to.
|
||||
an error if you try to. By default, no arguments are donated.
|
||||
inline: Specify whether this function should be inlined into enclosing
|
||||
jaxprs (rather than being represented as an application of the xla_call
|
||||
primitive with its own subjaxpr). Default False.
|
||||
|
||||
Returns:
|
||||
A wrapped version of ``fun``, set up for just-in-time compilation.
|
||||
@ -304,7 +307,7 @@ def _python_jit(
|
||||
donate_argnums: Union[int, Iterable[int]] = (),
|
||||
inline: bool = False,
|
||||
) -> F:
|
||||
"""The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit."""
|
||||
# The Python implementation of `jax.jit`, being slowly replaced by _cpp_jit.
|
||||
_check_callable(fun)
|
||||
static_argnums, static_argnames = _infer_argnums_and_argnames(
|
||||
fun, static_argnums, static_argnames)
|
||||
@ -365,15 +368,13 @@ def _cpp_jit(
|
||||
donate_argnums: Union[int, Iterable[int]] = (),
|
||||
inline: bool = False,
|
||||
) -> F:
|
||||
"""An implementation of `jit` that tries to do as much as possible in C++.
|
||||
|
||||
The goal of this function is to speed up the time it takes to process the
|
||||
arguments, find the correct C++ executable, start the transfer of arguments
|
||||
and schedule the computation.
|
||||
As long as it does not support all features of the Python implementation
|
||||
the C++ code will fallback to `_python_jit` when it faces some unsupported
|
||||
feature.
|
||||
"""
|
||||
# An implementation of `jit` that tries to do as much as possible in C++.
|
||||
# The goal of this function is to speed up the time it takes to process the
|
||||
# arguments, find the correct C++ executable, start the transfer of arguments
|
||||
# and schedule the computation.
|
||||
# As long as it does not support all features of the Python implementation
|
||||
# the C++ code will fallback to `_python_jit` when it faces some unsupported
|
||||
# feature.
|
||||
_check_callable(fun)
|
||||
static_argnums, static_argnames = _infer_argnums_and_argnames(
|
||||
fun, static_argnums, static_argnames)
|
||||
|
Loading…
x
Reference in New Issue
Block a user