Merge pull request #735 from google/issue711

fix bug in scan utility _convert_zeros
This commit is contained in:
Matthew Johnson 2019-05-20 09:42:43 -07:00 committed by GitHub
commit 5cbaf75d28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 8 deletions

View File

@ -428,16 +428,25 @@ def _maybe_tracer_tuple_to_abstract_tuple(tup):
### scan
def _convert_zeros(convert_symbolic, example, tangent):
if tangent is ad.zero:
if not convert_symbolic:
def _convert_zeros(instantiate, example, tangent):
t = type(instantiate)
if t is bool:
if instantiate:
return ad.instantiate_zeros(example, tangent)
elif tangent is ad_util.zero:
return core.unit
else:
return ad.zeros_like_jaxval(example)
elif type(tangent) is ad.TangentTuple:
return core.pack(map(_convert_zeros, convert_symbolic, example, tangent))
raise TypeError(tangent) # not clear if ever reachable
elif t is tuple:
if type(tangent) is ad.TangentTuple:
return core.pack(map(_convert_zeros, instantiate, example, tangent))
elif tangent is ad_util.zero:
zeros = [ad_util.zero] * len(instantiate)
return core.pack(map(_convert_zeros, instantiate, example, zeros))
else:
raise TypeError(tangent)
else:
return tangent
raise TypeError(t)
def _demote_aval_rank(xs):
assert isinstance(xs, core.AbstractValue)
@ -641,7 +650,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
length = kwargs.pop('length')
forward = kwargs.pop('forward')
assert not kwargs
in_pvs, in_consts = unzip2([t.pval for t in tracers])
in_pvs, _ = unzip2([t.pval for t in tracers])
sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)
sc_carry = sc_init
@ -819,7 +828,19 @@ def _make_typed_jaxpr(traceable, in_avals):
class FixedPointError(Exception): pass
# We use a custom bind for scan just to add some error checks
def scan_bind(consts, init, xs, forward, length, jaxpr):
if not core.skip_checks:
assert type(jaxpr.in_avals) is tuple
consts_aval, init_aval, xs_aval = jaxpr.in_avals
assert type(jaxpr.out_aval) is core.AbstractTuple
carry_aval, y_aval = jaxpr.out_aval
assert init_aval == carry_aval
return core.Primitive.bind(scan_p, consts, init, xs,
forward=forward, length=length, jaxpr=jaxpr)
scan_p = core.Primitive("scan")
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(_scan_impl)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose

View File

@ -592,6 +592,37 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = (onp.zeros_like(W_trans), onp.zeros_like(W_out))
self.assertAllClose(ans, expected, check_dtypes=False)
def testIssue711(self):
# Tests reverse-mode differentiation through a scan for which the scanned
# function also involves reverse-mode differentiation.
# See https://github.com/google/jax/issues/711
def harmonic_bond(conf, params):
return np.sum(conf * params)
def minimize_structure(test_params):
energy_fn = partial(harmonic_bond, params=test_params)
grad_fn = api.grad(energy_fn)
def apply_carry(carry, _):
i, x = carry
new_x = x - 0.1 * api.grad(energy_fn)(x)
new_carry = (i+1, new_x)
return new_carry, _
x0 = np.array([1., 2., 3.])
carry_final, _ = lax.scan(apply_carry, (0, x0), np.zeros((75, 0)))
_, x_final = carry_final
return x_final
initial_params = 0.5
minimize_structure(initial_params) # doesn't crash
def loss(test_params):
x_final = minimize_structure(test_params)
return np.sum(np.sin(1.0 - x_final))
api.grad(loss)(0.25) # doesn't crash
if __name__ == '__main__':
absltest.main()