mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax] set leaf and node counts when creating a tuple pytree definition
PiperOrigin-RevId: 388479354
This commit is contained in:
parent
0d8ef03a93
commit
86c48ccb7c
@ -19,6 +19,7 @@ import re
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.tree_util import _process_pytree
|
||||
@ -202,6 +203,18 @@ class TreeTest(jtu.JaxTestCase):
|
||||
_, c1 = tree_util.tree_flatten((7,))
|
||||
self.assertEqual([c0, c1], tree.children())
|
||||
|
||||
def testTreedefTupleFromChildren(self):
|
||||
# https://github.com/google/jax/issues/7377
|
||||
# TODO(frostig): remove after the minimum jaxlib is is 0.1.70 or newer.
|
||||
if jax.lib._xla_extension_version < 29:
|
||||
self.skipTest("fixed in future jaxlib")
|
||||
tree = ((1, 2, (3, 4)), (5,))
|
||||
leaves, treedef1 = tree_util.tree_flatten(tree)
|
||||
treedef2 = tree_util.treedef_tuple(treedef1.children())
|
||||
self.assertEqual(treedef1.num_leaves, len(leaves))
|
||||
self.assertEqual(treedef1.num_leaves, treedef2.num_leaves)
|
||||
self.assertEqual(treedef1.num_nodes, treedef2.num_nodes)
|
||||
|
||||
def testFlattenUpTo(self):
|
||||
_, tree = tree_util.tree_flatten([(1, 2), None, ATuple(foo=3, bar=7)])
|
||||
out = tree.flatten_up_to([({
|
||||
|
Loading…
x
Reference in New Issue
Block a user