mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #4828 from j-towns:api-doc-fixes
PiperOrigin-RevId: 341395571
This commit is contained in:
commit
ffff3a42fc
@ -89,6 +89,7 @@ Parallelization (:code:`pmap`)
|
||||
.. autofunction:: hessian
|
||||
.. autofunction:: jvp
|
||||
.. autofunction:: linearize
|
||||
.. autofunction:: linear_transpose
|
||||
.. autofunction:: vjp
|
||||
.. autoclass:: custom_jvp
|
||||
|
||||
|
27
jax/api.py
27
jax/api.py
@ -469,6 +469,13 @@ def xla_computation(fun: Callable,
|
||||
the output of ``fun`` and where the leaves are objects with ``shape`` and
|
||||
``dtype`` attributes representing the corresponding types of the output
|
||||
leaves.
|
||||
donate_argnums: Specify which arguments are "donated" to the computation.
|
||||
It is safe to donate arguments if you no longer need them once the
|
||||
computation has finished. In some cases XLA can make use of donated
|
||||
buffers to reduce the amount of memory needed to perform a computation,
|
||||
for example recycling one of your input buffers to store a result. You
|
||||
should not re-use buffers that you donate to a computation, JAX will raise
|
||||
an error if you try to.
|
||||
|
||||
Returns:
|
||||
A wrapped version of ``fun`` that when applied to example arguments returns
|
||||
@ -718,7 +725,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
|
||||
holomorphic. If True, inputs and outputs must be complex. Default False.
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
respect to integer valued inputs. The gradient of an integer input will
|
||||
have a trivial vector-space dtype (float0). Default False.
|
||||
|
||||
@ -774,11 +781,11 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
argnums: Optional, integer or sequence of integers. Specifies which
|
||||
positional argument(s) to differentiate with respect to (default 0).
|
||||
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
|
||||
holomorphic. If True, inputs and outputs must be complex. Default False.
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
respect to integer valued inputs. The gradient of an integer input will
|
||||
have a trivial vector-space dtype (float0). Default False.
|
||||
|
||||
@ -953,7 +960,7 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
positional argument(s) to differentiate with respect to (default ``0``).
|
||||
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
|
||||
holomorphic. Default False.
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
respect to integer valued inputs. The gradient of an integer input will
|
||||
have a trivial vector-space dtype (float0). Default False.
|
||||
|
||||
@ -1953,11 +1960,11 @@ def make_jaxpr(fun: Callable,
|
||||
|
||||
Returns:
|
||||
A wrapped version of ``fun`` that when applied to example arguments returns
|
||||
a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
|
||||
argument ``return_shape`` is ``True``, then the returned function instead
|
||||
returns a pair where the first element is the ``ClosedJaxpr``
|
||||
representation of ``fun`` and the second element is a pytree representing
|
||||
the structure, shape, and dtypes of the output of ``fun``.
|
||||
a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
|
||||
argument ``return_shape`` is ``True``, then the returned function instead
|
||||
returns a pair where the first element is the ``ClosedJaxpr``
|
||||
representation of ``fun`` and the second element is a pytree representing
|
||||
the structure, shape, and dtypes of the output of ``fun``.
|
||||
|
||||
A ``jaxpr`` is JAX's intermediate representation for program traces. The
|
||||
``jaxpr`` language is based on the simply-typed first-order lambda calculus
|
||||
|
Loading…
x
Reference in New Issue
Block a user