mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pytrees] fix function underlying tree-flattening with keys
There were two bugs in the _generate_keypaths function underlying tree_flatten_with_path, leading to disagreement between `len(tree_flatten(x)[0])` and `len(tree_flatten_with_path(x)[0])` for some `x` 1. pytree nodes that weren't registered as pytree-nodes-with-keys were treated as leaves 2. namedtuples that were registered as pytree nodes were being flattened as generic namedtuples rather than using the explicitly registered flattener
This commit is contained in:
parent
bab1098866
commit
82c0035a50
@ -648,18 +648,24 @@ def _generate_key_paths_(
|
||||
if is_leaf and is_leaf(tree):
|
||||
yield key_path, tree
|
||||
return
|
||||
handler = _registry_with_keypaths.get(type(tree))
|
||||
if handler:
|
||||
key_children, _ = handler.flatten_with_keys(tree)
|
||||
key_handler = _registry_with_keypaths.get(type(tree))
|
||||
handler = _registry.get(type(tree))
|
||||
if key_handler:
|
||||
key_children, _ = key_handler.flatten_with_keys(tree)
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
elif handler:
|
||||
children, _ = handler.to_iter(tree)
|
||||
for i, c in enumerate(children):
|
||||
k = FlattenedIndexKey(i)
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
|
||||
elif tree is not None: # Some strictly leaf type, like int or numpy array
|
||||
yield key_path, tree
|
||||
else:
|
||||
yield key_path, tree # strict leaf type
|
||||
|
||||
|
||||
def tree_map_with_path(f: Callable[..., Any],
|
||||
|
@ -37,6 +37,10 @@ ATuple = collections.namedtuple("ATuple", ("foo", "bar"))
|
||||
class ANamedTupleSubclass(ATuple):
|
||||
pass
|
||||
|
||||
ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar"))
|
||||
tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar),
|
||||
lambda bar, foo: ATuple2(foo[0], bar))
|
||||
|
||||
class AnObject:
|
||||
|
||||
def __init__(self, x, y, z):
|
||||
@ -558,6 +562,14 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(tree_util.tree_flatten(tree)[0],
|
||||
[l for _, l in tree_util.tree_flatten_with_path(tree)[0]])
|
||||
|
||||
def testPyTreeWithoutKeysIsntTreatedAsLeaf(self):
|
||||
leaves, _ = tree_util.tree_flatten_with_path(Special([1, 2], [3, 4]))
|
||||
self.assertLen(leaves, 4)
|
||||
|
||||
def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
|
||||
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
|
||||
self.assertLen(leaves, 1)
|
||||
|
||||
|
||||
class RavelUtilTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user