From 8677d99267c679525e9335b18decb3b4ed986e42 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 28 Jul 2022 19:22:55 -0700 Subject: [PATCH] promise to flatten trees in left-to-right order --- jax/_src/tree_util.py | 4 ++++ tests/tree_util_test.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 458635b91..d90257e1f 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -43,6 +43,10 @@ else: def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None): """Flattens a pytree. + The flattening order (i.e. the order of elements in the output list) + is deterministic, corresponding to a left-to-right depth-first tree + traversal. + Args: tree: a pytree to flatten. is_leaf: an optionally specified function that will be called at each diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 735d1086f..a95347fde 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -230,6 +230,14 @@ class TreeTest(jtu.JaxTestCase): self.assertEqual(tree_util.tree_structure((3,)), tree_util.treedef_tuple((tree_util.tree_structure(3),))) + def testFlattenOrder(self): + flat1, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9]) + flat2, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9]) + flat3, _ = tree_util.tree_flatten([0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9]) + self.assertEqual(flat1, list(range(10))) + self.assertEqual(flat2, list(range(10))) + self.assertEqual(flat3, list(range(10))) + def testFlattenUpTo(self): _, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)]) out = tree.flatten_up_to([({