Merge pull request #16018 from ZacCranko:tree_reduce_is_leaf

PiperOrigin-RevId: 534165099
This commit is contained in:
jax authors 2023-05-22 13:31:04 -07:00
commit 13f5090c4c
2 changed files with 14 additions and 5 deletions

View File

@ -263,22 +263,26 @@ no_initializer = object()
@overload
def tree_reduce(function: Callable[[T, Any], T],
tree: Any) -> T:
tree: Any,
*,
is_leaf: Optional[Callable[[Any], bool]] = None) -> T:
...
@overload
def tree_reduce(function: Callable[[T, Any], T],
tree: Any,
initializer: T) -> T:
initializer: T,
is_leaf: Optional[Callable[[Any], bool]] = None) -> T:
...
def tree_reduce(function: Callable[[T, Any], T],
tree: Any,
initializer: Any = no_initializer) -> T:
initializer: Any = no_initializer,
is_leaf: Optional[Callable[[Any], bool]] = None) -> T:
if initializer is no_initializer:
return functools.reduce(function, tree_leaves(tree))
return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf))
else:
return functools.reduce(function, tree_leaves(tree), initializer)
return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf), initializer)
def tree_all(tree: Any) -> bool:
return all(tree_leaves(tree))

View File

@ -298,6 +298,11 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(out, (((1, [3]), (2, None)),
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
def testTreeReduceWithIsLeafArgument(self):
out = tree_util.tree_reduce(lambda x, y: x + y, [(1, 2), [(3, 4), (5, 6)]],
is_leaf=lambda l: isinstance(l, tuple))
self.assertEqual(out, (1, 2, 3, 4, 5, 6))
@parameterized.parameters(
tree_util.tree_leaves,
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])