mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
add a test for tree_reduce with is_leaf argument
This commit is contained in:
parent
75fc830f26
commit
42b2a80df2
@ -298,6 +298,11 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, (((1, [3]), (2, None)),
|
||||
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
|
||||
|
||||
def testTreeReduceWithIsLeafArgument(self):
|
||||
out = tree_util.tree_reduce(lambda x, y: x + y, [(1, 2), [(3, 4), (5, 6)]],
|
||||
is_leaf=lambda l: isinstance(l, tuple))
|
||||
self.assertEqual(out, (1, 2, 3, 4, 5, 6))
|
||||
|
||||
@parameterized.parameters(
|
||||
tree_util.tree_leaves,
|
||||
lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user