mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Address review comments, add a test.
This commit is contained in:
parent
f2bc287865
commit
5bdbcc42d5
14
jax/api.py
14
jax/api.py
@ -784,6 +784,14 @@ def lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pval, *py_args):
|
||||
|
||||
return apply_jaxtree_fun(fun, io_tree, *py_args)
|
||||
|
||||
def _check_inexact_input_vjp(x):
|
||||
aval = core.get_aval(x)
|
||||
if not onp.issubdtype(aval.dtype, onp.inexact):
|
||||
msg = ("Primal inputs to reverse-mode differentiation must be of float "
|
||||
"or complex type, got type {}")
|
||||
raise TypeError(msg.format(aval.dtype.name))
|
||||
|
||||
|
||||
def vjp(fun, *primals, **kwargs):
|
||||
"""Compute a (reverse-mode) vector-Jacobian product of `fun`.
|
||||
|
||||
@ -823,11 +831,7 @@ def vjp(fun, *primals, **kwargs):
|
||||
fun = lu.wrap_init(fun)
|
||||
primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))
|
||||
_check_args(primals_flat)
|
||||
for p in tree_flatten(primals)[0]:
|
||||
if not onp.issubdtype(onp.result_type(p), onp.inexact):
|
||||
msg = ("Primal inputs to reverse-mode differentiation must be of float "
|
||||
"or complex type, got type {}")
|
||||
raise ValueError(msg.format(onp.result_type(p)))
|
||||
tree_map(_check_inexact_input_vjp, primals)
|
||||
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
|
||||
if not has_aux:
|
||||
out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
|
||||
|
@ -815,6 +815,12 @@ class APITest(jtu.JaxTestCase):
|
||||
if eqn.bound_subjaxprs)
|
||||
self.assertEqual(len(subjaxpr.eqns), 1)
|
||||
|
||||
def test_grad_of_int_errors(self):
|
||||
dfn = grad(lambda x: x ** 2)
|
||||
jtu.check_raises(lambda: dfn(3), TypeError,
|
||||
"Primal inputs to reverse-mode differentiation must be of "
|
||||
"float or complex type, got type int32")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user