add back an OptState that is a proper pytree

This commit is contained in:
Matthew Johnson 2019-03-21 19:23:00 -07:00
parent 04cfa11ebe
commit 6fdf64d7ee
2 changed files with 22 additions and 11 deletions

View File

@ -22,13 +22,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
import collections
import functools
import operator
import jax.numpy as np
from jax.core import pack
from jax.util import partial, safe_zip, safe_map, unzip2
from jax.tree_util import tree_map, prefix_multimap, tree_structure
from jax.tree_util import (tree_map, prefix_multimap, tree_structure,
register_pytree_node)
map = safe_map
zip = safe_zip
@ -40,22 +42,31 @@ def optimizer(opt_maker):
init_fun, update_fun = opt_maker(*args, **kwargs)
@functools.wraps(init_fun)
def treemapped_init_fun(x0_tree):
return tree_map(init_fun, x0_tree)
def tree_init_fun(x0_tree):
prefix = tree_structure(x0_tree)
return OptState(prefix, tree_map(init_fun, x0_tree))
@functools.wraps(update_fun)
def treemapped_update_fun(i, grad_tree, state_tree):
tdf = tree_structure(grad_tree)
return prefix_multimap(partial(update_fun, i), tdf, grad_tree, state_tree)
def tree_update_fun(i, grad_tree, state_tree):
assert type(state_tree) is OptState
prefix, tree = state_tree
tree = prefix_multimap(partial(update_fun, i), prefix, grad_tree, tree)
return OptState(prefix, tree)
return treemapped_init_fun, treemapped_update_fun
return tree_init_fun, tree_update_fun
return tree_opt_maker
def iterate(state_trees):
def iterate(state_tree):
"""Extract the current iterate from an optimizer state."""
raise NotImplementedError # TODO don't have tree structure... assume flat?
assert type(state_tree) is OptState
prefix, tree = state_tree
return prefix_multimap(operator.itemgetter(0), prefix, tree)
get_params = iterate
OptState = collections.namedtuple("OptState", ["prefix", "tree"])
register_pytree_node(OptState, lambda xs: ([xs.tree], xs.prefix),
lambda prefix, xs: OptState(prefix, xs[0]))
# optimizers
@optimizer

View File

@ -74,7 +74,7 @@ def prefix_multimap(f, treedef, tree, *rest):
all_children.append(other_children)
all_children = zip(*all_children)
new_children = [tree_multimap_prefix(f, td, *xs)
new_children = [prefix_multimap(f, td, *xs)
for td, xs in zip(treedef.children, all_children)]
return node_type.from_iterable(node_data, new_children)