diff --git a/docs/jax.test_util.rst b/docs/jax.test_util.rst index 063eadaca..278fa37ba 100644 --- a/docs/jax.test_util.rst +++ b/docs/jax.test_util.rst @@ -1,5 +1,5 @@ ``jax.test_util`` module -=================== +======================== .. currentmodule:: jax.test_util diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 93a6c29c2..220342ce5 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -247,6 +247,25 @@ def _merge_tolerance(tol, default): def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''): + """Check a JVP from automatic differentiation against finite differences. + + Gradients are only checked in a single randomly chosen direction, which + ensures that the finite difference calculation does not become prohibitively + expensive even for large input/output spaces. + + Args: + f: function to check at ``f(*args)``. + f_vjp: function that calculates ``jax.jvp`` applied to ``f``. Typically this + should be ``functools.partial(jax.jvp, f))``. + args: tuple of argument values. + atol: absolute tolerance for gradient equality. + rtol: relative tolerance for gradient equality. + eps: step size used for finite differences. + err_msg: additional error message to include if checks fail. + + Raises: + AssertionError: if gradients do not match. + """ atol = _merge_tolerance(atol, default_gradient_tolerance) rtol = _merge_tolerance(rtol, default_gradient_tolerance) rng = np.random.RandomState(0) @@ -266,6 +285,25 @@ def check_jvp(f, f_jvp, args, atol=None, rtol=None, eps=EPS, err_msg=''): def check_vjp(f, f_vjp, args, atol=None, rtol=None, eps=EPS, err_msg=''): + """Check a VJP from automatic differentiation against finite differences. + + Gradients are only checked in a single randomly chosen direction, which + ensures that the finite difference calculation does not become prohibitively + expensive even for large input/output spaces. + + Args: + f: function to check at ``f(*args)``. + f_vjp: function that calculates ``jax.vjp`` applied to ``f``. Typically this + should be ``functools.partial(jax.jvp, f))``. + args: tuple of argument values. + atol: absolute tolerance for gradient equality. + rtol: relative tolerance for gradient equality. + eps: step size used for finite differences. + err_msg: additional error message to include if checks fail. + + Raises: + AssertionError: if gradients do not match. + """ atol = _merge_tolerance(atol, default_gradient_tolerance) rtol = _merge_tolerance(rtol, default_gradient_tolerance) _rand_like = partial(rand_like, np.random.RandomState(0))