[jax] set leaf and node counts when creating a tuple pytree definition

PiperOrigin-RevId: 388479354
This commit is contained in:
Roy Frostig 2021-08-03 09:53:53 -07:00 committed by jax authors
parent 0d8ef03a93
commit 86c48ccb7c

View File

@ -19,6 +19,7 @@ import re
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import test_util as jtu
from jax import tree_util
from jax._src.tree_util import _process_pytree
@ -202,6 +203,18 @@ class TreeTest(jtu.JaxTestCase):
_, c1 = tree_util.tree_flatten((7,))
self.assertEqual([c0, c1], tree.children())
def testTreedefTupleFromChildren(self):
# https://github.com/google/jax/issues/7377
# TODO(frostig): remove after the minimum jaxlib is is 0.1.70 or newer.
if jax.lib._xla_extension_version < 29:
self.skipTest("fixed in future jaxlib")
tree = ((1, 2, (3, 4)), (5,))
leaves, treedef1 = tree_util.tree_flatten(tree)
treedef2 = tree_util.treedef_tuple(treedef1.children())
self.assertEqual(treedef1.num_leaves, len(leaves))
self.assertEqual(treedef1.num_leaves, treedef2.num_leaves)
self.assertEqual(treedef1.num_nodes, treedef2.num_nodes)
def testFlattenUpTo(self):
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
out = tree.flatten_up_to([({