more nonlinear evaluation in backward_pass.py (#2214)

* more nonlinear evaluation in backward_pass.py

fixes #2180

* add tests, fix #1963 by not raising error eagerly
This commit is contained in:
Matthew Johnson 2020-02-11 15:56:53 -08:00 committed by GitHub
parent 5e77789afe
commit 7ca43f0ea9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 23 deletions

View File

@ -183,21 +183,12 @@ def backward_pass(jaxpr: core.Jaxpr, consts, args, cotangents_in):
else:
write_primal(eqn.outvars[0], ans)
else:
call_jaxpr = eqn.params["call_jaxpr"]
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
elif eqn.primitive is not pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
# we special-case remat_call here because it can be mixed linear /
# nonlinear, so we always evaluate it even if it has a linear part
if eqn.primitive is pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), eqn.params)
if any(not is_linear(v) for v in eqn.invars):
ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), params)
map(write_primal, eqn.outvars, ans)
ct_env = {}
@ -260,12 +251,10 @@ def _eval_primals(jaxpr, args):
else:
write_primal(eqn.outvars[0], ans)
else:
call_jaxpr = eqn.params["call_jaxpr"]
if (eqn.primitive is pe.remat_call_p or
not any(is_linear(v) for v in eqn.invars)):
ans = _eval_subjaxpr_primals(
eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), eqn.params)
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
if any(not is_linear(v) for v in eqn.invars):
ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), params)
map(write_primal, eqn.outvars, ans)
return map(read_primal, jaxpr.outvars)

View File

@ -1006,11 +1006,11 @@ def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, l
if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
raise NotImplementedError
if not all(init_lin):
raise NotImplementedError
pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963
consts, init, xs = split_list(args, [num_consts, num_carry])
ires, consts = split_list(consts, [num_ires])
xs, eres = split_list(xs, [sum(xs_lin)])
consts, _, xs = split_list(args, [num_consts, num_carry])
ires, _ = split_list(consts, [num_ires])
_, eres = split_list(xs, [sum(xs_lin)])
assert not any(r is ad.undefined_primal for r in ires)
assert not any(r is ad.undefined_primal for r in eres)

View File

@ -1660,6 +1660,45 @@ class APITest(jtu.JaxTestCase):
u0 = np.ones_like(target)
loss(u0, target, 10) # doesn't crash
def test_remat_jit3(self):
# Regression test: taking api.grad of a function wrapped in api.remat must
# not raise. Two variants are exercised below; neither checks numerical
# values -- the test passes as long as both grad calls complete.
# https://github.com/google/jax/issues/2180
# Variant 1: a dot product followed by two einsums, where the intermediate
# `a` is reused by both einsums and `b` feeds the second one.
def f(w, x):
a = np.dot(x, w)
b = np.einsum("btd,bTd->btT", a, a)
c = np.einsum("btT,btd->btd", b, a)
return np.sum(c)
# Smallest possible inputs: a 1x1 weight and a 1x1x1 batch.
w = np.ones([1, 1])
x = np.ones([1, 1, 1])
f = api.remat(f)
api.grad(f)(w, x) # doesn't crash
# Variant 2: the remat-wrapped function calls a jit-decorated helper twice,
# with the second call consuming the result of the first.
@api.jit
def mul(a, b):
return a * b
# Note: this rebinds `f`; the first variant has already been run above.
def f(w, x):
a = mul(w, x)
b = mul(a, a)
return b
# Scalar inputs this time.
w = 1.
x = 1.
f = api.remat(f)
api.grad(f)(w, x) # doesn't crash
def test_remat_scan2(self):
# Regression test: taking jax.grad through a jax.remat-wrapped function
# whose body runs a lax.scan must not raise. No values are checked.
# https://github.com/google/jax/issues/1963
def scan_bug(x0):
# Scan body: increment the carry by one and emit no per-step output.
f = lambda x, _: (x + 1, None)
def scanned_f(x, _):
# Single-iteration scan with no scanned-over inputs (xs=None, length=1);
# index [0] keeps the first element of scan's result, and None is
# returned as the second output.
return lax.scan(f, x, xs=None, length=1)[0], None
# The second positional argument and the second output are both unused.
x, _ = jax.remat(scanned_f)(x0, None)
return x
jax.grad(scan_bug)(1.0) # doesn't crash
def test_trivial_computations(self):
x = np.array([1, 2, 3])
y = api.jit(lambda x: x)(x)