mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
promise to flatten trees in left-to-right order
This commit is contained in:
parent
a636bd3468
commit
8677d99267
@ -43,6 +43,10 @@ else:
|
||||
def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
|
||||
"""Flattens a pytree.
|
||||
|
||||
The flattening order (i.e. the order of elements in the output list)
|
||||
is deterministic, corresponding to a left-to-right depth-first tree
|
||||
traversal.
|
||||
|
||||
Args:
|
||||
tree: a pytree to flatten.
|
||||
is_leaf: an optionally specified function that will be called at each
|
||||
|
@ -230,6 +230,14 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(tree_util.tree_structure((3,)),
|
||||
tree_util.treedef_tuple((tree_util.tree_structure(3),)))
|
||||
|
||||
def testFlattenOrder(self):
|
||||
flat1, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
|
||||
flat2, _ = tree_util.tree_flatten([0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9])
|
||||
flat3, _ = tree_util.tree_flatten([0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9])
|
||||
self.assertEqual(flat1, list(range(10)))
|
||||
self.assertEqual(flat2, list(range(10)))
|
||||
self.assertEqual(flat3, list(range(10)))
|
||||
|
||||
def testFlattenUpTo(self):
|
||||
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
|
||||
out = tree.flatten_up_to([({
|
||||
|
Loading…
x
Reference in New Issue
Block a user