Readd jax.test_util.check_jvp and check_vjp to the public JAX API.

This commit is contained in:
Peter Hawkins 2021-10-14 11:55:11 -04:00
parent e638445f50
commit c491203bdd

View File

@ -13,7 +13,7 @@
# limitations under the License.
# flake8: noqa: F401
# TODO(phawkins): remove all exports except check_grads.
# TODO(phawkins): remove all exports except check_grads/check_jvp/check_vjp.
from jax._src.test_util import (
JaxTestCase,
JaxTestLoader,
@ -21,6 +21,8 @@ from jax._src.test_util import (
check_close,
check_eq,
check_grads as check_grads,
check_jvp as check_jvp,
check_vjp as check_vjp,
device_under_test,
format_shape_dtype_string,
rand_uniform,