mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #1102 from georgedahl/fix
Fix pack_optimizer_state to correctly use tuples everywhere in the pa…
This commit is contained in:
commit
72c0dcf1cf
@ -506,5 +506,6 @@ def pack_optimizer_state(marked_pytree):
|
||||
sentinels, tree_def = tree_flatten(marked_pytree)
|
||||
assert all(isinstance(s, JoinPoint) for s in sentinels)
|
||||
subtrees = [s.subtree for s in sentinels]
|
||||
packed_state, subtree_defs = unzip2(map(tree_flatten, subtrees))
|
||||
states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees))
|
||||
packed_state = pack(map(pack, states_flat))
|
||||
return OptimizerState(packed_state, tree_def, subtree_defs)
|
||||
|
@ -288,6 +288,13 @@ class OptimizerTests(jtu.JaxTestCase):
|
||||
J2 = jacfwd(loss, argnums=(0,))(initial_params)
|
||||
self.assertAllClose(J1, J2, check_dtypes=True)
|
||||
|
||||
def testUnpackPackRoundTrip(self):
|
||||
opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
|
||||
params = [{'w': onp.random.randn(1, 2), 'bias': onp.random.randn(2)}]
|
||||
expected = opt_init(params)
|
||||
ans = optimizers.pack_optimizer_state(
|
||||
optimizers.unpack_optimizer_state(expected))
|
||||
self.assertEqual(ans, expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user