[pytrees] fix function underlying tree-flattening with keys

There were two bugs in the _generate_keypaths function underlying tree_flatten_with_path, leading to disagreement between `len(tree_flatten(x)[0])` and `len(tree_flatten_with_path(x)[0])` for some `x`
1. pytree nodes that weren't registered as pytree-nodes-with-keys were treated as leaves
2. namedtuples that were registered as pytree nodes were being flattened as generic namedtuples rather than using the explicitly registered flattener
This commit is contained in:
Matthew Johnson 2023-03-17 19:08:53 -07:00
parent bab1098866
commit 82c0035a50
2 changed files with 24 additions and 6 deletions

View File

@ -648,18 +648,24 @@ def _generate_key_paths_(
if is_leaf and is_leaf(tree):
yield key_path, tree
return
handler = _registry_with_keypaths.get(type(tree))
if handler:
key_children, _ = handler.flatten_with_keys(tree)
key_handler = _registry_with_keypaths.get(type(tree))
handler = _registry.get(type(tree))
if key_handler:
key_children, _ = key_handler.flatten_with_keys(tree)
for k, c in key_children:
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
elif handler:
children, _ = handler.to_iter(tree)
for i, c in enumerate(children):
k = FlattenedIndexKey(i)
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
# handle namedtuple as a special case, based on heuristic
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
for k, c in key_children:
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
elif tree is not None: # Some strictly leaf type, like int or numpy array
yield key_path, tree
else:
yield key_path, tree # strict leaf type
def tree_map_with_path(f: Callable[..., Any],

View File

@ -37,6 +37,10 @@ ATuple = collections.namedtuple("ATuple", ("foo", "bar"))
class ANamedTupleSubclass(ATuple):
pass
ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar"))
tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar),
lambda bar, foo: ATuple2(foo[0], bar))
class AnObject:
def __init__(self, x, y, z):
@ -558,6 +562,14 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(tree_util.tree_flatten(tree)[0],
[l for _, l in tree_util.tree_flatten_with_path(tree)[0]])
def testPyTreeWithoutKeysIsntTreatedAsLeaf(self):
leaves, _ = tree_util.tree_flatten_with_path(Special([1, 2], [3, 4]))
self.assertLen(leaves, 4)
def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
self.assertLen(leaves, 1)
class RavelUtilTest(jtu.JaxTestCase):