mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
almost done with a prefix_multimap solution
This commit is contained in:
parent
7be3649744
commit
04cfa11ebe
@ -28,7 +28,7 @@ import functools
|
||||
import jax.numpy as np
|
||||
from jax.core import pack
|
||||
from jax.util import partial, safe_zip, safe_map, unzip2
|
||||
from jax.tree_util import tree_multimap, tree_flatten, tree_unflatten
|
||||
from jax.tree_util import tree_map, prefix_multimap, tree_structure
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
@ -41,27 +41,19 @@ def optimizer(opt_maker):
|
||||
|
||||
@functools.wraps(init_fun)
|
||||
def treemapped_init_fun(x0_tree):
|
||||
x0_flat, treedef = tree_flatten(x0_tree)
|
||||
state_flat = zip(*map(init_fun, x0_flat))
|
||||
state_trees = map(partial(tree_unflatten, treedef), state_flat)
|
||||
assert all(treedef == tree_flatten(tree)[1] for tree in state_trees)
|
||||
return tuple(state_trees)
|
||||
return tree_map(init_fun, x0_tree)
|
||||
|
||||
@functools.wraps(update_fun)
|
||||
def treemapped_update_fun(i, grad_tree, state_trees):
|
||||
grad_flat, treedef = tree_flatten(grad_tree)
|
||||
state_flat, treedefs = unzip2(map(tree_flatten, state_trees))
|
||||
assert all(td == treedef for td in treedefs)
|
||||
state_flat = zip(*map(partial(update_fun, i), grad_flat, zip(*state_flat)))
|
||||
state_trees = map(partial(tree_unflatten, treedef), state_flat)
|
||||
return tuple(state_trees)
|
||||
def treemapped_update_fun(i, grad_tree, state_tree):
|
||||
tdf = tree_structure(grad_tree)
|
||||
return prefix_multimap(partial(update_fun, i), tdf, grad_tree, state_tree)
|
||||
|
||||
return treemapped_init_fun, treemapped_update_fun
|
||||
return tree_opt_maker
|
||||
|
||||
def iterate(state_trees):
|
||||
"""Extract the current iterate from an optimizer state."""
|
||||
return state_trees[0]
|
||||
raise NotImplementedError # TODO don't have tree structure... assume flat?
|
||||
get_params = iterate
|
||||
|
||||
# optimizers
|
||||
|
@ -24,7 +24,7 @@ from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p,
|
||||
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, zero, Zero)
|
||||
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
|
||||
from ..tree_util import process_pytree, build_tree, register_pytree_node, prune, tree_map
|
||||
from ..tree_util import process_pytree, build_tree, register_pytree_node, tree_map
|
||||
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
|
||||
|
||||
from six.moves import builtins, reduce
|
||||
|
@ -36,15 +36,17 @@ def tree_map(f, tree):
|
||||
|
||||
|
||||
def tree_multimap(f, tree, *rest):
|
||||
tree_type = type(tree)
|
||||
node_type = node_types.get(tree_type)
|
||||
node_type = node_types.get(type(tree))
|
||||
if node_type:
|
||||
children, node_spec = node_type.to_iterable(tree)
|
||||
all_children = [children]
|
||||
for other_tree in rest:
|
||||
other_children, other_node_spec = node_type.to_iterable(other_tree)
|
||||
if other_node_spec != node_spec:
|
||||
raise TypeError('Mismatch: {} != {}'.format(other_node_spec, node_spec))
|
||||
other_node_type = node_types.get(type(other_tree))
|
||||
if node_type != other_node_type:
|
||||
raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
|
||||
other_children, other_node_data = node_type.to_iterable(other_tree)
|
||||
if other_node_data != node_spec:
|
||||
raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_spec))
|
||||
all_children.append(other_children)
|
||||
|
||||
new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)]
|
||||
@ -53,6 +55,30 @@ def tree_multimap(f, tree, *rest):
|
||||
return f(tree, *rest)
|
||||
|
||||
|
||||
def prefix_multimap(f, treedef, tree, *rest):
|
||||
"""Like tree_multimap but only maps down through a tree prefix."""
|
||||
if treedef is leaf:
|
||||
return f(tree, *rest)
|
||||
else:
|
||||
node_type = node_types.get(type(tree))
|
||||
if node_type != treedef.node_type:
|
||||
raise TypeError('Mismatch: {} != {}'.format(treedef.node_type, node_type))
|
||||
children, node_data = node_type.to_iterable(tree)
|
||||
if node_data != treedef.node_data:
|
||||
raise TypeError('Mismatch: {} != {}'.format(treedef.node_data, node_data))
|
||||
all_children = [children]
|
||||
for other_tree in rest:
|
||||
other_children, other_node_data = node_type.to_iterable(other_tree)
|
||||
if other_node_data != node_data:
|
||||
raise TypeError('Mismatch: {} != {}'.format(other_node_data, node_data))
|
||||
all_children.append(other_children)
|
||||
all_children = zip(*all_children)
|
||||
|
||||
new_children = [tree_multimap_prefix(f, td, *xs)
|
||||
for td, xs in zip(treedef.children, all_children)]
|
||||
return node_type.from_iterable(node_data, new_children)
|
||||
|
||||
|
||||
def tree_reduce(f, tree):
|
||||
flat, _ = tree_flatten(tree)
|
||||
return reduce(f, flat)
|
||||
@ -132,15 +158,6 @@ def tree_structure(tree):
|
||||
return spec
|
||||
|
||||
|
||||
def prune(treedef, tuple_tree):
|
||||
if treedef is leaf:
|
||||
return tuple_tree
|
||||
elif treedef.children:
|
||||
return tuple(map(prune, treedef.children, tuple_tree))
|
||||
else:
|
||||
return ()
|
||||
|
||||
|
||||
class PyTreeDef(object):
|
||||
def __init__(self, node_type, node_data, children):
|
||||
self.node_type = node_type
|
||||
|
Loading…
x
Reference in New Issue
Block a user