mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Readd jax.test_util.check_jvp and check_vjp to the public JAX API.
This commit is contained in:
parent
e638445f50
commit
c491203bdd
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
# TODO(phawkins): remove all exports except check_grads.
|
||||
# TODO(phawkins): remove all exports except check_grads/check_jvp/check_vjp.
|
||||
from jax._src.test_util import (
|
||||
JaxTestCase,
|
||||
JaxTestLoader,
|
||||
@ -21,6 +21,8 @@ from jax._src.test_util import (
|
||||
check_close,
|
||||
check_eq,
|
||||
check_grads as check_grads,
|
||||
check_jvp as check_jvp,
|
||||
check_vjp as check_vjp,
|
||||
device_under_test,
|
||||
format_shape_dtype_string,
|
||||
rand_uniform,
|
||||
|
Loading…
x
Reference in New Issue
Block a user