Merge pull request #4828 from j-towns:api-doc-fixes

PiperOrigin-RevId: 341395571
This commit is contained in:
jax authors 2020-11-09 06:46:48 -08:00
commit ffff3a42fc
2 changed files with 18 additions and 10 deletions

View File

@ -89,6 +89,7 @@ Parallelization (:code:`pmap`)
.. autofunction:: hessian
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: linear_transpose
.. autofunction:: vjp
.. autoclass:: custom_jvp

View File

@ -469,6 +469,13 @@ def xla_computation(fun: Callable,
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
buffers to reduce the amount of memory needed to perform a computation,
for example recycling one of your input buffers to store a result. You
should not re-use buffers that you donate to a computation, JAX will raise
an error if you try to.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
@ -718,7 +725,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
@ -774,11 +781,11 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
@ -953,7 +960,7 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
positional argument(s) to differentiate with respect to (default ``0``).
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
@ -1953,11 +1960,11 @@ def make_jaxpr(fun: Callable,
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
argument ``return_shape`` is ``True``, then the returned function instead
returns a pair where the first element is the ``ClosedJaxpr``
representation of ``fun`` and the second element is a pytree representing
the structure, shape, and dtypes of the output of ``fun``.
a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
argument ``return_shape`` is ``True``, then the returned function instead
returns a pair where the first element is the ``ClosedJaxpr``
representation of ``fun`` and the second element is a pytree representing
the structure, shape, and dtypes of the output of ``fun``.
A ``jaxpr`` is JAX's intermediate representation for program traces. The
``jaxpr`` language is based on the simply-typed first-order lambda calculus