Address review comments, add a test.

This commit is contained in:
Peter Hawkins 2019-06-24 10:45:42 -04:00
parent f2bc287865
commit 5bdbcc42d5
2 changed files with 15 additions and 5 deletions

View File

@ -784,6 +784,14 @@ def lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pval, *py_args):
return apply_jaxtree_fun(fun, io_tree, *py_args)
def _check_inexact_input_vjp(x):
aval = core.get_aval(x)
if not onp.issubdtype(aval.dtype, onp.inexact):
msg = ("Primal inputs to reverse-mode differentiation must be of float "
"or complex type, got type {}")
raise TypeError(msg.format(aval.dtype.name))
def vjp(fun, *primals, **kwargs):
"""Compute a (reverse-mode) vector-Jacobian product of `fun`.
@ -823,11 +831,7 @@ def vjp(fun, *primals, **kwargs):
fun = lu.wrap_init(fun)
primals_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, primals))
_check_args(primals_flat)
for p in tree_flatten(primals)[0]:
if not onp.issubdtype(onp.result_type(p), onp.inexact):
msg = ("Primal inputs to reverse-mode differentiation must be of float "
"or complex type, got type {}")
raise ValueError(msg.format(onp.result_type(p)))
tree_map(_check_inexact_input_vjp, primals)
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(fun, in_trees)
if not has_aux:
out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)

View File

@ -815,6 +815,12 @@ class APITest(jtu.JaxTestCase):
if eqn.bound_subjaxprs)
self.assertEqual(len(subjaxpr.eqns), 1)
def test_grad_of_int_errors(self):
dfn = grad(lambda x: x ** 2)
jtu.check_raises(lambda: dfn(3), TypeError,
"Primal inputs to reverse-mode differentiation must be of "
"float or complex type, got type int32")
if __name__ == '__main__':
absltest.main()