mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jax.tree_util.tree_leaves_with_path(tree).
PiperOrigin-RevId: 539609052
This commit is contained in:
parent
9b10384b43
commit
ed073aa6c9
@ -22,6 +22,7 @@ List of Functions
|
||||
tree_flatten
|
||||
tree_flatten_with_path
|
||||
tree_leaves
|
||||
tree_leaves_with_path
|
||||
tree_map
|
||||
tree_map_with_path
|
||||
tree_reduce
|
||||
|
@ -723,6 +723,20 @@ def tree_flatten_with_path(
|
||||
return _generate_key_paths(tree, is_leaf), tree_def
|
||||
|
||||
|
||||
def tree_leaves_with_path(
|
||||
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
|
||||
) -> List[Tuple[KeyPath, Any]]:
|
||||
"""Flattens a pytree like ``tree_leaves``, but also returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree to flatten. If it contains a custom type, it must be
|
||||
registered with ``register_pytree_with_keys``.
|
||||
Returns:
|
||||
A list of key-leaf pairs, each of which contains a leaf and its key path.
|
||||
"""
|
||||
return _generate_key_paths(tree, is_leaf)
|
||||
|
||||
|
||||
def generate_key_paths(
|
||||
tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
|
||||
) -> List[Tuple[KeyPath, Any]]:
|
||||
|
@ -60,6 +60,7 @@ from jax._src.tree_util import (
|
||||
register_pytree_with_keys_class as register_pytree_with_keys_class,
|
||||
tree_map_with_path as tree_map_with_path,
|
||||
tree_flatten_with_path as tree_flatten_with_path,
|
||||
tree_leaves_with_path as tree_leaves_with_path,
|
||||
keystr as keystr,
|
||||
SequenceKey as SequenceKey,
|
||||
DictKey as DictKey,
|
||||
|
@ -480,6 +480,13 @@ class TreeTest(jtu.JaxTestCase):
|
||||
from_one_tree = tree_util.tree_map(lambda a: a + 2, tree1)
|
||||
self.assertEqual(from_two_trees, from_one_tree)
|
||||
|
||||
def testTreeLeavesWithPath(self):
|
||||
tree = [{i: i for i in range(10)}]
|
||||
actual = tree_util.tree_leaves_with_path(tree)
|
||||
expected = [((tree_util.SequenceKey(0), tree_util.DictKey(i)), i)
|
||||
for i in range(10)]
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def testKeyStr(self):
|
||||
tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
|
||||
flattened, _ = tree_util.tree_flatten_with_path(tree1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user