mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #735 from google/issue711
fix bug in scan utility _convert_zeros
This commit is contained in:
commit
5cbaf75d28
@ -428,16 +428,25 @@ def _maybe_tracer_tuple_to_abstract_tuple(tup):
|
||||
|
||||
### scan
|
||||
|
||||
def _convert_zeros(convert_symbolic, example, tangent):
|
||||
if tangent is ad.zero:
|
||||
if not convert_symbolic:
|
||||
def _convert_zeros(instantiate, example, tangent):
|
||||
t = type(instantiate)
|
||||
if t is bool:
|
||||
if instantiate:
|
||||
return ad.instantiate_zeros(example, tangent)
|
||||
elif tangent is ad_util.zero:
|
||||
return core.unit
|
||||
else:
|
||||
return ad.zeros_like_jaxval(example)
|
||||
elif type(tangent) is ad.TangentTuple:
|
||||
return core.pack(map(_convert_zeros, convert_symbolic, example, tangent))
|
||||
raise TypeError(tangent) # not clear if ever reachable
|
||||
elif t is tuple:
|
||||
if type(tangent) is ad.TangentTuple:
|
||||
return core.pack(map(_convert_zeros, instantiate, example, tangent))
|
||||
elif tangent is ad_util.zero:
|
||||
zeros = [ad_util.zero] * len(instantiate)
|
||||
return core.pack(map(_convert_zeros, instantiate, example, zeros))
|
||||
else:
|
||||
raise TypeError(tangent)
|
||||
else:
|
||||
return tangent
|
||||
raise TypeError(t)
|
||||
|
||||
def _demote_aval_rank(xs):
|
||||
assert isinstance(xs, core.AbstractValue)
|
||||
@ -641,7 +650,7 @@ def _scan_partial_eval(trace, *tracers, **kwargs):
|
||||
length = kwargs.pop('length')
|
||||
forward = kwargs.pop('forward')
|
||||
assert not kwargs
|
||||
in_pvs, in_consts = unzip2([t.pval for t in tracers])
|
||||
in_pvs, _ = unzip2([t.pval for t in tracers])
|
||||
sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)
|
||||
|
||||
sc_carry = sc_init
|
||||
@ -819,7 +828,19 @@ def _make_typed_jaxpr(traceable, in_avals):
|
||||
class FixedPointError(Exception): pass
|
||||
|
||||
|
||||
# We use a custom bind for scan just to add some error checks
|
||||
def scan_bind(consts, init, xs, forward, length, jaxpr):
|
||||
if not core.skip_checks:
|
||||
assert type(jaxpr.in_avals) is tuple
|
||||
consts_aval, init_aval, xs_aval = jaxpr.in_avals
|
||||
assert type(jaxpr.out_aval) is core.AbstractTuple
|
||||
carry_aval, y_aval = jaxpr.out_aval
|
||||
assert init_aval == carry_aval
|
||||
return core.Primitive.bind(scan_p, consts, init, xs,
|
||||
forward=forward, length=length, jaxpr=jaxpr)
|
||||
|
||||
scan_p = core.Primitive("scan")
|
||||
scan_p.def_custom_bind(scan_bind)
|
||||
scan_p.def_impl(_scan_impl)
|
||||
ad.primitive_jvps[scan_p] = _scan_jvp
|
||||
ad.primitive_transposes[scan_p] = _scan_transpose
|
||||
|
@ -592,6 +592,37 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = (onp.zeros_like(W_trans), onp.zeros_like(W_out))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testIssue711(self):
|
||||
# Tests reverse-mode differentiation through a scan for which the scanned
|
||||
# function also involves reverse-mode differentiation.
|
||||
# See https://github.com/google/jax/issues/711
|
||||
def harmonic_bond(conf, params):
|
||||
return np.sum(conf * params)
|
||||
|
||||
def minimize_structure(test_params):
|
||||
energy_fn = partial(harmonic_bond, params=test_params)
|
||||
grad_fn = api.grad(energy_fn)
|
||||
|
||||
def apply_carry(carry, _):
|
||||
i, x = carry
|
||||
new_x = x - 0.1 * api.grad(energy_fn)(x)
|
||||
new_carry = (i+1, new_x)
|
||||
return new_carry, _
|
||||
|
||||
x0 = np.array([1., 2., 3.])
|
||||
carry_final, _ = lax.scan(apply_carry, (0, x0), np.zeros((75, 0)))
|
||||
_, x_final = carry_final
|
||||
return x_final
|
||||
|
||||
initial_params = 0.5
|
||||
minimize_structure(initial_params) # doesn't crash
|
||||
|
||||
def loss(test_params):
|
||||
x_final = minimize_structure(test_params)
|
||||
return np.sum(np.sin(1.0 - x_final))
|
||||
|
||||
api.grad(loss)(0.25) # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user