mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add back an OptState that is a proper pytree
This commit is contained in:
parent
04cfa11ebe
commit
6fdf64d7ee
@ -22,13 +22,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import operator
|
||||
import collections
|
||||
import functools
|
||||
import operator
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.core import pack
|
||||
from jax.util import partial, safe_zip, safe_map, unzip2
|
||||
from jax.tree_util import tree_map, prefix_multimap, tree_structure
|
||||
from jax.tree_util import (tree_map, prefix_multimap, tree_structure,
|
||||
register_pytree_node)
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
@ -40,22 +42,31 @@ def optimizer(opt_maker):
|
||||
init_fun, update_fun = opt_maker(*args, **kwargs)
|
||||
|
||||
@functools.wraps(init_fun)
|
||||
def treemapped_init_fun(x0_tree):
|
||||
return tree_map(init_fun, x0_tree)
|
||||
def tree_init_fun(x0_tree):
|
||||
prefix = tree_structure(x0_tree)
|
||||
return OptState(prefix, tree_map(init_fun, x0_tree))
|
||||
|
||||
@functools.wraps(update_fun)
|
||||
def treemapped_update_fun(i, grad_tree, state_tree):
|
||||
tdf = tree_structure(grad_tree)
|
||||
return prefix_multimap(partial(update_fun, i), tdf, grad_tree, state_tree)
|
||||
def tree_update_fun(i, grad_tree, state_tree):
|
||||
assert type(state_tree) is OptState
|
||||
prefix, tree = state_tree
|
||||
tree = prefix_multimap(partial(update_fun, i), prefix, grad_tree, tree)
|
||||
return OptState(prefix, tree)
|
||||
|
||||
return treemapped_init_fun, treemapped_update_fun
|
||||
return tree_init_fun, tree_update_fun
|
||||
return tree_opt_maker
|
||||
|
||||
def iterate(state_trees):
|
||||
def iterate(state_tree):
|
||||
"""Extract the current iterate from an optimizer state."""
|
||||
raise NotImplementedError # TODO don't have tree structure... assume flat?
|
||||
assert type(state_tree) is OptState
|
||||
prefix, tree = state_tree
|
||||
return prefix_multimap(operator.itemgetter(0), prefix, tree)
|
||||
get_params = iterate
|
||||
|
||||
OptState = collections.namedtuple("OptState", ["prefix", "tree"])
|
||||
register_pytree_node(OptState, lambda xs: ([xs.tree], xs.prefix),
|
||||
lambda prefix, xs: OptState(prefix, xs[0]))
|
||||
|
||||
# optimizers
|
||||
|
||||
@optimizer
|
||||
|
@ -74,7 +74,7 @@ def prefix_multimap(f, treedef, tree, *rest):
|
||||
all_children.append(other_children)
|
||||
all_children = zip(*all_children)
|
||||
|
||||
new_children = [tree_multimap_prefix(f, td, *xs)
|
||||
new_children = [prefix_multimap(f, td, *xs)
|
||||
for td, xs in zip(treedef.children, all_children)]
|
||||
return node_type.from_iterable(node_data, new_children)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user