add docstrings for check_vjp and check_jvp

This commit is contained in:
Stephan Hoyer 2025-01-23 17:31:13 -08:00
parent 458f6a6efe
commit b2afb5bf4f
2 changed files with 39 additions and 1 deletions

View File

@ -1,5 +1,5 @@
``jax.test_util`` module
===================
========================
.. currentmodule:: jax.test_util

View File

@ -247,6 +247,25 @@ def _merge_tolerance(tol, default):
def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
"""Check a JVP from automatic differentiation against finite differences.
Gradients are only checked in a single randomly chosen direction, which
ensures that the finite difference calculation does not become prohibitively
expensive even for large input/output spaces.
Args:
f: function to check at ``f(*args)``.
f_vjp: function that calculates ``jax.jvp`` applied to ``f``. Typically this
should be ``functools.partial(jax.jvp, f))``.
args: tuple of argument values.
atol: absolute tolerance for gradient equality.
rtol: relative tolerance for gradient equality.
eps: step size used for finite differences.
err_msg: additional error message to include if checks fail.
Raises:
AssertionError: if gradients do not match.
"""
atol = _merge_tolerance(atol, default_gradient_tolerance)
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
rng = np.random.RandomState(0)
@ -266,6 +285,25 @@ def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''):
"""Check a VJP from automatic differentiation against finite differences.
Gradients are only checked in a single randomly chosen direction, which
ensures that the finite difference calculation does not become prohibitively
expensive even for large input/output spaces.
Args:
f: function to check at ``f(*args)``.
f_vjp: function that calculates ``jax.vjp`` applied to ``f``. Typically this
should be ``functools.partial(jax.jvp, f))``.
args: tuple of argument values.
atol: absolute tolerance for gradient equality.
rtol: relative tolerance for gradient equality.
eps: step size used for finite differences.
err_msg: additional error message to include if checks fail.
Raises:
AssertionError: if gradients do not match.
"""
atol = _merge_tolerance(atol, default_gradient_tolerance)
rtol = _merge_tolerance(rtol, default_gradient_tolerance)
_rand_like = partial(rand_like, np.random.RandomState(0))