Add jax.tree_util.tree_leaves_with_path(tree).

PiperOrigin-RevId: 539609052
This commit is contained in:
Tom Hennigan 2023-06-12 04:12:55 -07:00 committed by jax authors
parent 9b10384b43
commit ed073aa6c9
4 changed files with 23 additions and 0 deletions

View File

@ -22,6 +22,7 @@ List of Functions
tree_flatten
tree_flatten_with_path
tree_leaves
tree_leaves_with_path
tree_map
tree_map_with_path
tree_reduce

View File

@ -723,6 +723,20 @@ def tree_flatten_with_path(
return _generate_key_paths(tree, is_leaf), tree_def
def tree_leaves_with_path(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[KeyPath, Any]]:
"""Flattens a pytree like ``tree_leaves``, but also returns each leaf's key path.
Args:
tree: a pytree to flatten. If it contains a custom type, it must be
registered with ``register_pytree_with_keys``.
Returns:
A list of key-leaf pairs, each of which contains a leaf and its key path.
"""
return _generate_key_paths(tree, is_leaf)
def generate_key_paths(
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[KeyPath, Any]]:

View File

@ -60,6 +60,7 @@ from jax._src.tree_util import (
register_pytree_with_keys_class as register_pytree_with_keys_class,
tree_map_with_path as tree_map_with_path,
tree_flatten_with_path as tree_flatten_with_path,
tree_leaves_with_path as tree_leaves_with_path,
keystr as keystr,
SequenceKey as SequenceKey,
DictKey as DictKey,

View File

@ -480,6 +480,13 @@ class TreeTest(jtu.JaxTestCase):
from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1)
self.assertEqual(from_two_trees, from_one_tree)
def testTreeLeavesWithPath(self):
tree = [{i: i for i in range(10)}]
actual = tree_util.tree_leaves_with_path(tree)
expected = [((tree_util.SequenceKey(0), tree_util.DictKey(i)), i)
for i in range(10)]
self.assertEqual(actual, expected)
def testKeyStr(self):
tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
flattened, _ = tree_util.tree_flatten_with_path(tree1)