promise to flatten trees in left-to-right order

This commit is contained in:
Roy Frostig 2022-07-28 19:22:55 -07:00
parent a636bd3468
commit 8677d99267
2 changed files with 12 additions and 0 deletions

View File

@ -43,6 +43,10 @@ else:
def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
"""Flattens a pytree.
The flattening order (i.e. the order of elements in the output list)
is deterministic, corresponding to a left-to-right depth-first tree
traversal.
Args:
tree: a pytree to flatten.
is_leaf: an optionally specified function that will be called at each

View File

@ -230,6 +230,14 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(tree_util.tree_structure((3,)),
tree_util.treedef_tuple((tree_util.tree_structure(3),)))
def testFlattenOrder(self):
flat1, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
flat2, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
flat3, _ = tree_util.tree_flatten([0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9])
self.assertEqual(flat1, list(range(10)))
self.assertEqual(flat2, list(range(10)))
self.assertEqual(flat3, list(range(10)))
def testFlattenUpTo(self):
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
out = tree.flatten_up_to([({