mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
add docstrings for check_vjp and check_jvp
This commit is contained in:
parent
458f6a6efe
commit
b2afb5bf4f
@ -1,5 +1,5 @@
|
||||
``jax.test_util`` module
|
||||
===================
|
||||
========================
|
||||
|
||||
.. currentmodule:: jax.test_util
|
||||
|
||||
|
@ -247,6 +247,25 @@ def _merge_tolerance(tol, default):
|
||||
|
||||
|
||||
def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
|
||||
"""Check a JVP from automatic differentiation against finite differences.
|
||||
|
||||
Gradients are only checked in a single randomly chosen direction, which
|
||||
ensures that the finite difference calculation does not become prohibitively
|
||||
expensive even for large input/output spaces.
|
||||
|
||||
Args:
|
||||
f: function to check at ``f(*args)``.
|
||||
f_vjp: function that calculates ``jax.jvp`` applied to ``f``. Typically this
|
||||
should be ``functools.partial(jax.jvp, f))``.
|
||||
args: tuple of argument values.
|
||||
atol: absolute tolerance for gradient equality.
|
||||
rtol: relative tolerance for gradient equality.
|
||||
eps: step size used for finite differences.
|
||||
err_msg: additional error message to include if checks fail.
|
||||
|
||||
Raises:
|
||||
AssertionError: if gradients do not match.
|
||||
"""
|
||||
atol = _merge_tolerance(atol, default_gradient_tolerance)
|
||||
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
|
||||
rng = np.random.RandomState(0)
|
||||
@ -266,6 +285,25 @@ def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
|
||||
|
||||
|
||||
def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
|
||||
"""Check a VJP from automatic differentiation against finite differences.
|
||||
|
||||
Gradients are only checked in a single randomly chosen direction, which
|
||||
ensures that the finite difference calculation does not become prohibitively
|
||||
expensive even for large input/output spaces.
|
||||
|
||||
Args:
|
||||
f: function to check at ``f(*args)``.
|
||||
f_vjp: function that calculates ``jax.vjp`` applied to ``f``. Typically this
|
||||
should be ``functools.partial(jax.jvp, f))``.
|
||||
args: tuple of argument values.
|
||||
atol: absolute tolerance for gradient equality.
|
||||
rtol: relative tolerance for gradient equality.
|
||||
eps: step size used for finite differences.
|
||||
err_msg: additional error message to include if checks fail.
|
||||
|
||||
Raises:
|
||||
AssertionError: if gradients do not match.
|
||||
"""
|
||||
atol = _merge_tolerance(atol, default_gradient_tolerance)
|
||||
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
|
||||
_rand_like = partial(rand_like, np.random.RandomState(0))
|
||||
|
Loading…
x
Reference in New Issue
Block a user