add a test for tree_reduce with is_leaf argument

This commit is contained in:
Matthew Johnson 2023-05-16 15:37:20 -07:00
parent 75fc830f26
commit 42b2a80df2

View File

@ -298,6 +298,11 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(out, (((1, [3]), (2, None)),
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
def testTreeReduceWithIsLeafArgument(self):
out = tree_util.tree_reduce(lambda x, y: x + y, [(1, 2), [(3, 4), (5, 6)]],
is_leaf=lambda l: isinstance(l, tuple))
self.assertEqual(out, (1, 2, 3, 4, 5, 6))
@parameterized.parameters(
tree_util.tree_leaves,
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])