unify tree_map and tree_multimap

This commit is contained in:
Roy Frostig 2021-04-28 19:44:20 -07:00
parent 4fcfaeb8a9
commit b8f9dd6269
2 changed files with 3 additions and 22 deletions

View File

@ -20,7 +20,6 @@ List of Functions
tree_flatten
tree_leaves
tree_map
tree_multimap
tree_reduce
tree_structure
tree_transpose

View File

@ -141,28 +141,8 @@ def register_pytree_node_class(cls):
register_pytree_node(cls, op.methodcaller('tree_flatten'), cls.tree_unflatten)
return cls
def tree_map(f: Callable[[Any], Any], tree: Any,
def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
"""Maps a function over a pytree to produce a new pytree.
Args:
f: unary function to be applied at each leaf.
tree: a pytree to be mapped over.
is_leaf: an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether
the flattening should traverse the current object, or if it should be
stopped immediately, with the whole subtree being treated as a leaf.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by ``f(x)`` where ``x`` is the value at the corresponding leaf in
the input ``tree``.
"""
leaves, treedef = tree_flatten(tree, is_leaf)
return treedef.unflatten(map(f, leaves))
def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
"""Maps a multi-input function over pytree args to produce a new pytree.
Args:
@ -187,6 +167,8 @@ def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
tree_multimap = tree_map
# TODO(mattjj,phawkins): consider removing this function
def _process_pytree(process_node, tree):
leaves, treedef = pytree.flatten(tree)