Merge pull request #1102 from georgedahl/fix

Fix pack_optimizer_state to correctly use tuples everywhere in the pa…
This commit is contained in:
James Bradbury 2019-08-03 21:40:33 -07:00 committed by GitHub
commit 72c0dcf1cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 1 deletions

View File

@ -506,5 +506,6 @@ def pack_optimizer_state(marked_pytree):
sentinels, tree_def = tree_flatten(marked_pytree)
assert all(isinstance(s, JoinPoint) for s in sentinels)
subtrees = [s.subtree for s in sentinels]
packed_state, subtree_defs = unzip2(map(tree_flatten, subtrees))
states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees))
packed_state = pack(map(pack, states_flat))
return OptimizerState(packed_state, tree_def, subtree_defs)

View File

@ -288,6 +288,13 @@ class OptimizerTests(jtu.JaxTestCase):
J2 = jacfwd(loss, argnums=(0,))(initial_params)
self.assertAllClose(J1, J2, check_dtypes=True)
def testUnpackPackRoundTrip(self):
opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
params = [{'w': onp.random.randn(1, 2), 'bias': onp.random.randn(2)}]
expected = opt_init(params)
ans = optimizers.pack_optimizer_state(
optimizers.unpack_optimizer_state(expected))
self.assertEqual(ans, expected)
if __name__ == '__main__':
absltest.main()