mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
b4acfe0640
commit
eb01b8bfef
20
jax/api.py
20
jax/api.py
@ -762,13 +762,23 @@ def linearize(fun, *primals):
|
||||
out_primal, out_pval, jaxpr, consts = ad.linearize(jaxtree_fun, *primals_flat)
|
||||
out_tree = out_tree()
|
||||
out_primal_py = build_tree(out_tree, out_primal)
|
||||
lifted_jvp = partial(lift_linearized, jaxpr, consts, (in_trees, out_tree), out_pval)
|
||||
primal_avals = list(map(core.get_aval, primals_flat))
|
||||
lifted_jvp = partial(lift_linearized, jaxpr, primal_avals, consts,
|
||||
(in_trees, out_tree), out_pval)
|
||||
return out_primal_py, lifted_jvp
|
||||
|
||||
def lift_linearized(jaxpr, consts, io_tree, out_pval, *py_args):
|
||||
def fun(*args):
|
||||
primals = pack(args) # doesn't matter what these are-they'll be ignored
|
||||
tangents = pack(args)
|
||||
def lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pval, *py_args):
|
||||
def fun(*tangents):
|
||||
tangent_avals = list(map(core.get_aval, tangents))
|
||||
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
|
||||
try:
|
||||
core.lattice_join(primal_aval, tangent_aval)
|
||||
except TypeError:
|
||||
msg = ("linearized function called on tangent values inconsistent with "
|
||||
"the original primal values.")
|
||||
raise ValueError(msg)
|
||||
primals = pack(tangents) # doesn't matter what these are-they'll be ignored
|
||||
tangents = pack(tangents)
|
||||
_, ans = eval_jaxpr(jaxpr, consts, (), primals, tangents)
|
||||
return pe.merge_pvals(ans, out_pval)
|
||||
|
||||
|
@ -784,6 +784,20 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
assert len(api.make_jaxpr(fun)(1).eqns) == 0
|
||||
|
||||
def test_issue_871(self):
|
||||
T = np.array([[1., 2.], [3., 4.], [5., 6.]])
|
||||
x = np.array([1, 2, 3])
|
||||
|
||||
y, f_jvp = api.linearize(np.sum, x)
|
||||
jtu.check_raises(lambda: f_jvp(T), ValueError,
|
||||
("linearized function called on tangent values "
|
||||
"inconsistent with the original primal values."))
|
||||
|
||||
y, f_jvp = api.linearize(api.jit(np.sum), x)
|
||||
jtu.check_raises(lambda: f_jvp(T), ValueError,
|
||||
("linearized function called on tangent values "
|
||||
"inconsistent with the original primal values."))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user