improve linearize error message

fixes #871
This commit is contained in:
Matthew Johnson 2019-06-18 09:18:44 -07:00
parent b4acfe0640
commit eb01b8bfef
2 changed files with 29 additions and 5 deletions

View File

@ -762,13 +762,23 @@ def linearize(fun, *primals):
out_primal, out_pval, jaxpr, consts = ad.linearize(jaxtree_fun, *primals_flat)
out_tree = out_tree()
out_primal_py = build_tree(out_tree, out_primal)
lifted_jvp = partial(lift_linearized, jaxpr, consts, (in_trees, out_tree), out_pval)
primal_avals = list(map(core.get_aval, primals_flat))
lifted_jvp = partial(lift_linearized, jaxpr, primal_avals, consts,
(in_trees, out_tree), out_pval)
return out_primal_py, lifted_jvp
def lift_linearized(jaxpr, consts, io_tree, out_pval, *py_args):
def fun(*args):
primals = pack(args) # doesn't matter what these are-they'll be ignored
tangents = pack(args)
def lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pval, *py_args):
def fun(*tangents):
tangent_avals = list(map(core.get_aval, tangents))
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
try:
core.lattice_join(primal_aval, tangent_aval)
except TypeError:
msg = ("linearized function called on tangent values inconsistent with "
"the original primal values.")
raise ValueError(msg)
primals = pack(tangents) # doesn't matter what these are-they'll be ignored
tangents = pack(tangents)
_, ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
return pe.merge_pvals(ans, out_pval)

View File

@ -784,6 +784,20 @@ class APITest(jtu.JaxTestCase):
assert len(api.make_jaxpr(fun)(1).eqns) == 0
def test_issue_871(self):
T = np.array([[1., 2.], [3., 4.], [5., 6.]])
x = np.array([1, 2, 3])
y, f_jvp = api.linearize(np.sum, x)
jtu.check_raises(lambda: f_jvp(T), ValueError,
("linearized function called on tangent values "
"inconsistent with the original primal values."))
y, f_jvp = api.linearize(api.jit(np.sum), x)
jtu.check_raises(lambda: f_jvp(T), ValueError,
("linearized function called on tangent values "
"inconsistent with the original primal values."))
if __name__ == '__main__':
absltest.main()