almost done with a prefix_multimap solution

This commit is contained in:
Matthew Johnson 2019-03-21 19:08:19 -07:00
parent 7be3649744
commit 04cfa11ebe
3 changed files with 38 additions and 29 deletions

View File

@ -28,7 +28,7 @@ import functools
import jax.numpy as np
from jax.core import pack
from jax.util import partial, safe_zip, safe_map, unzip2
from jax.tree_util import tree_multimap, tree_flatten, tree_unflatten
from jax.tree_util import tree_map, prefix_multimap, tree_structure
map = safe_map
zip = safe_zip
@ -41,27 +41,19 @@ def optimizer(opt_maker):
@functools.wraps(init_fun)
def treemapped_init_fun(x0_tree):
x0_flat, treedef = tree_flatten(x0_tree)
state_flat = zip(*map(init_fun, x0_flat))
state_trees = map(partial(tree_unflatten, treedef), state_flat)
assert all(treedef == tree_flatten(tree)[1] for tree in state_trees)
return tuple(state_trees)
return tree_map(init_fun, x0_tree)
@functools.wraps(update_fun)
def treemapped_update_fun(i, grad_tree, state_trees):
grad_flat, treedef = tree_flatten(grad_tree)
state_flat, treedefs = unzip2(map(tree_flatten, state_trees))
assert all(td == treedef for td in treedefs)
state_flat = zip(*map(partial(update_fun, i), grad_flat, zip(*state_flat)))
state_trees = map(partial(tree_unflatten, treedef), state_flat)
return tuple(state_trees)
def treemapped_update_fun(i, grad_tree, state_tree):
tdf = tree_structure(grad_tree)
return prefix_multimap(partial(update_fun, i), tdf, grad_tree, state_tree)
return treemapped_init_fun, treemapped_update_fun
return tree_opt_maker
def iterate(state_trees):
"""Extract the current iterate from an optimizer state."""
return state_trees[0]
raise NotImplementedError # TODO don't have tree structure... assume flat?
get_params = iterate
# optimizers

View File

@ -24,7 +24,7 @@ from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p,
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, zero, Zero)
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
from ..tree_util import process_pytree, build_tree, register_pytree_node, prune, tree_map
from ..tree_util import process_pytree, build_tree, register_pytree_node, tree_map
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
from six.moves import builtins, reduce

View File

@ -36,15 +36,17 @@ def tree_map(f, tree):
def tree_multimap(f, tree, *rest):
tree_type = type(tree)
node_type = node_types.get(tree_type)
node_type = node_types.get(type(tree))
if node_type:
children, node_spec = node_type.to_iterable(tree)
all_children = [children]
for other_tree in rest:
other_children, other_node_spec = node_type.to_iterable(other_tree)
if other_node_spec != node_spec:
raise TypeError('Mismatch: {} != {}'.format(other_node_spec, node_spec))
other_node_type = node_types.get(type(other_tree))
if node_type != other_node_type:
raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
other_children, other_node_data = node_type.to_iterable(other_tree)
if other_node_data != node_spec:
raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_spec))
all_children.append(other_children)
new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)]
@ -53,6 +55,30 @@ def tree_multimap(f, tree, *rest):
return f(tree, *rest)
def prefix_multimap(f, treedef, tree, *rest):
"""Like tree_multimap but only maps down through a tree prefix."""
if treedef is leaf:
return f(tree, *rest)
else:
node_type = node_types.get(type(tree))
if node_type != treedef.node_type:
raise TypeError('Mismatch: {} != {}'.format(treedef.node_type, node_type))
children, node_data = node_type.to_iterable(tree)
if node_data != treedef.node_data:
raise TypeError('Mismatch: {} != {}'.format(treedef.node_data, node_data))
all_children = [children]
for other_tree in rest:
other_children, other_node_data = node_type.to_iterable(other_tree)
if other_node_data != node_data:
raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_data))
all_children.append(other_children)
all_children = zip(*all_children)
new_children = [tree_multimap_prefix(f, td, *xs)
for td, xs in zip(treedef.children, all_children)]
return node_type.from_iterable(node_data, new_children)
def tree_reduce(f, tree):
flat, _ = tree_flatten(tree)
return reduce(f, flat)
@ -132,15 +158,6 @@ def tree_structure(tree):
return spec
def prune(treedef, tuple_tree):
if treedef is leaf:
return tuple_tree
elif treedef.children:
return tuple(map(prune, treedef.children, tuple_tree))
else:
return ()
class PyTreeDef(object):
def __init__(self, node_type, node_data, children):
self.node_type = node_type