mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
more nonlinear evaluation in backward_pass.py (#2214)
* more nonlinear evaluation in backward_pass.py; fixes #2180
* add tests; fix #1963 by not raising an error eagerly
This commit is contained in:
parent
5e77789afe
commit
7ca43f0ea9
@ -183,21 +183,12 @@ def backward_pass(jaxpr: core.Jaxpr, consts, args, cotangents_in):
|
||||
else:
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
else:
|
||||
call_jaxpr = eqn.params["call_jaxpr"]
|
||||
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
|
||||
if any(is_linear(v) for v in eqn.invars):
|
||||
linear_eqns.append(eqn)
|
||||
elif eqn.primitive is not pe.remat_call_p:
|
||||
ans = _eval_subjaxpr_primals(
|
||||
eqn.primitive, call_jaxpr,
|
||||
map(read_primal, eqn.invars), eqn.params)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
|
||||
# we special-case remat_call here because it can be mixed linear /
|
||||
# nonlinear, so we always evaluate it even if it has a linear part
|
||||
if eqn.primitive is pe.remat_call_p:
|
||||
ans = _eval_subjaxpr_primals(
|
||||
eqn.primitive, call_jaxpr,
|
||||
map(read_primal, eqn.invars), eqn.params)
|
||||
if any(not is_linear(v) for v in eqn.invars):
|
||||
ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
|
||||
map(read_primal, eqn.invars), params)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
|
||||
ct_env = {}
|
||||
@ -260,12 +251,10 @@ def _eval_primals(jaxpr, args):
|
||||
else:
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
else:
|
||||
call_jaxpr = eqn.params["call_jaxpr"]
|
||||
if (eqn.primitive is pe.remat_call_p or
|
||||
not any(is_linear(v) for v in eqn.invars)):
|
||||
ans = _eval_subjaxpr_primals(
|
||||
eqn.primitive, call_jaxpr,
|
||||
map(read_primal, eqn.invars), eqn.params)
|
||||
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
|
||||
if any(not is_linear(v) for v in eqn.invars):
|
||||
ans = _eval_subjaxpr_primals(eqn.primitive, call_jaxpr,
|
||||
map(read_primal, eqn.invars), params)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
return map(read_primal, jaxpr.outvars)
|
||||
|
||||
|
@ -1006,11 +1006,11 @@ def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, l
|
||||
if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
|
||||
raise NotImplementedError
|
||||
if not all(init_lin):
|
||||
raise NotImplementedError
|
||||
pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963
|
||||
|
||||
consts, init, xs = split_list(args, [num_consts, num_carry])
|
||||
ires, consts = split_list(consts, [num_ires])
|
||||
xs, eres = split_list(xs, [sum(xs_lin)])
|
||||
consts, _, xs = split_list(args, [num_consts, num_carry])
|
||||
ires, _ = split_list(consts, [num_ires])
|
||||
_, eres = split_list(xs, [sum(xs_lin)])
|
||||
assert not any(r is ad.undefined_primal for r in ires)
|
||||
assert not any(r is ad.undefined_primal for r in eres)
|
||||
|
||||
|
@ -1660,6 +1660,45 @@ class APITest(jtu.JaxTestCase):
|
||||
u0 = np.ones_like(target)
|
||||
loss(u0, target, 10) # doesn't crash
|
||||
|
||||
def test_remat_jit3(self):
|
||||
# https://github.com/google/jax/issues/2180
|
||||
def f(w, x):
|
||||
a = np.dot(x, w)
|
||||
b = np.einsum("btd,bTd->btT", a, a)
|
||||
c = np.einsum("btT,btd->btd", b, a)
|
||||
return np.sum(c)
|
||||
|
||||
w = np.ones([1, 1])
|
||||
x = np.ones([1, 1, 1])
|
||||
f = api.remat(f)
|
||||
api.grad(f)(w, x) # doesn't crash
|
||||
|
||||
@api.jit
|
||||
def mul(a, b):
|
||||
return a * b
|
||||
|
||||
def f(w, x):
|
||||
a = mul(w, x)
|
||||
b = mul(a, a)
|
||||
return b
|
||||
|
||||
w = 1.
|
||||
x = 1.
|
||||
f = api.remat(f)
|
||||
api.grad(f)(w, x) # doesn't crash
|
||||
|
||||
def test_remat_scan2(self):
|
||||
# https://github.com/google/jax/issues/1963
|
||||
|
||||
def scan_bug(x0):
|
||||
f = lambda x, _: (x + 1, None)
|
||||
def scanned_f(x, _):
|
||||
return lax.scan(f, x, xs=None, length=1)[0], None
|
||||
x, _ = jax.remat(scanned_f)(x0, None)
|
||||
return x
|
||||
|
||||
jax.grad(scan_bug)(1.0) # doesn't crash
|
||||
|
||||
def test_trivial_computations(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = api.jit(lambda x: x)(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user