mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add docstring for defjvp_all
This commit is contained in:
parent
720dec4072
commit
ab20f0292c
@ -19,6 +19,6 @@ Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: jit, disable_jit, grad, value_and_grad, vmap, pmap, jacfwd, jacrev, hessian, jvp, linearize, vjp, make_jaxpr, eval_shape, custom_transforms
|
||||
:members: jit, disable_jit, grad, value_and_grad, vmap, pmap, jacfwd, jacrev, hessian, jvp, linearize, vjp, make_jaxpr, eval_shape, custom_transforms, defjvp, defjvp2, defjvp_all, defvjp, defvjp2, defvjp_all, custom_gradient
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
42
jax/api.py
42
jax/api.py
@ -1076,7 +1076,47 @@ def _check_custom_transforms_type(name, fun):
|
||||
raise TypeError(msg.format(name, type(fun)))
|
||||
|
||||
def defjvp_all(fun, custom_jvp):
|
||||
"""Define a custom JVP rule for a custom_transforms function."""
|
||||
"""Define a custom JVP rule for a custom_transforms function.
|
||||
|
||||
If ``fun`` represents a function with signature ``a -> b``, then
|
||||
``custom_jvp`` represents a function with signature ``a -> T a -> (b, Tb)``,
|
||||
where we use ``T x`` to represent a tangent type for the type ``x``.
|
||||
|
||||
Defining a custom JVP rule will also affect the dfeault VJP rule, which is
|
||||
derived from the JVP rule automatically via transposition.
|
||||
|
||||
Args:
|
||||
fun: a custom_transforms function.
|
||||
custom_jvp: a Python callable specifying the JVP rule, taking two tuples as
|
||||
arguments specifying the input primal values and tangent values,
|
||||
respectively. The tuple elements can be arrays, scalars, or (nested)
|
||||
standard Python containers (tuple/list/dict) thereof. Must be functionally
|
||||
pure.
|
||||
|
||||
Returns:
|
||||
None. A side-effect is that ``fun`` is associated with the JVP rule
|
||||
specified by ``custom_jvp``.
|
||||
|
||||
For example:
|
||||
|
||||
>>> @jax.custom_transforms
|
||||
... def f(x):
|
||||
... return np.sin(x ** 2)
|
||||
...
|
||||
>>> print(f(3.))
|
||||
0.4121185
|
||||
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
|
||||
>>> print(out_primal)
|
||||
0.4121185
|
||||
>>> print(out_tangent)
|
||||
-10.933563
|
||||
>>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
|
||||
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
|
||||
>>> print(out_primal)
|
||||
0.4121185
|
||||
>>> print(out_tangent)
|
||||
16.0
|
||||
"""
|
||||
_check_custom_transforms_type("defjvp_all", fun)
|
||||
def custom_transforms_jvp(primals, tangents, **params):
|
||||
jax_kwargs, jax_args = primals[0], primals[1:]
|
||||
|
Loading…
x
Reference in New Issue
Block a user