mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
unify tree_map and tree_multimap
This commit is contained in:
parent
4fcfaeb8a9
commit
b8f9dd6269
@ -20,7 +20,6 @@ List of Functions
|
||||
tree_flatten
|
||||
tree_leaves
|
||||
tree_map
|
||||
tree_multimap
|
||||
tree_reduce
|
||||
tree_structure
|
||||
tree_transpose
|
||||
|
@ -141,28 +141,8 @@ def register_pytree_node_class(cls):
|
||||
register_pytree_node(cls, op.methodcaller('tree_flatten'), cls.tree_unflatten)
|
||||
return cls
|
||||
|
||||
def tree_map(f: Callable[[Any], Any], tree: Any,
|
||||
def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
|
||||
"""Maps a function over a pytree to produce a new pytree.
|
||||
|
||||
Args:
|
||||
f: unary function to be applied at each leaf.
|
||||
tree: a pytree to be mapped over.
|
||||
is_leaf: an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether
|
||||
the flattening should traverse the current object, or if it should be
|
||||
stopped immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as `tree` but with the value at each
|
||||
leaf given by ``f(x)`` where ``x`` is the value at the corresponding leaf in
|
||||
the input ``tree``.
|
||||
"""
|
||||
leaves, treedef = tree_flatten(tree, is_leaf)
|
||||
return treedef.unflatten(map(f, leaves))
|
||||
|
||||
def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
|
||||
"""Maps a multi-input function over pytree args to produce a new pytree.
|
||||
|
||||
Args:
|
||||
@ -187,6 +167,8 @@ def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
|
||||
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
|
||||
|
||||
tree_multimap = tree_map
|
||||
|
||||
# TODO(mattjj,phawkins): consider removing this function
|
||||
def _process_pytree(process_node, tree):
|
||||
leaves, treedef = pytree.flatten(tree)
|
||||
|
Loading…
x
Reference in New Issue
Block a user