make namedtuples transparent (act as pytree nodes)

This commit is contained in:
Matthew Johnson 2019-05-20 10:08:33 -07:00
parent 5cbaf75d28
commit 88f691f896

View File

@ -56,7 +56,7 @@ def tree_map(f, tree):
leaf given by `f(x)` where `x` is the value at the corresponding leaf in
`tree`.
"""
node_type = node_types.get(type(tree))
node_type = _get_node_type(tree)
if node_type:
children, node_spec = node_type.to_iterable(tree)
new_children = [tree_map(f, child) for child in children]
@ -79,12 +79,12 @@ def tree_multimap(f, tree, *rest):
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
"""
node_type = node_types.get(type(tree))
node_type = _get_node_type(tree)
if node_type:
children, aux_data = node_type.to_iterable(tree)
all_children = [children]
for other_tree in rest:
other_node_type = node_types.get(type(other_tree))
other_node_type = _get_node_type(other_tree)
if node_type != other_node_type:
raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
other_children, other_aux_data = node_type.to_iterable(other_tree)
@ -113,7 +113,7 @@ def process_pytree(process_node, tree):
def walk_pytree(f_node, f_leaf, tree):
node_type = node_types.get(type(tree))
node_type = _get_node_type(tree)
if node_type:
children, node_spec = node_type.to_iterable(tree)
proc_children, child_specs = unzip2([walk_pytree(f_node, f_leaf, child)
@ -236,3 +236,20 @@ register_pytree_node(tuple, lambda xs: (xs, None), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda xs: (tuple(xs), None), lambda _, xs: list(xs))
register_pytree_node(dict, dict_to_iterable, lambda keys, xs: dict(zip(keys, xs)))
register_pytree_node(type(None), lambda z: ((), None), lambda _, xs: None)
# To handle namedtuples, we can't just use the standard table of node_types
# because every namedtuple creates its own type and thus would require its own
# entry in the table. Instead we use a heuristic check on the type itself to
# decide whether it's a namedtuple type, and if so treat it as a pytree node.
def _get_node_type(maybe_tree):
t = type(maybe_tree)
return node_types.get(t) or _namedtuple_node(t)
def _namedtuple_node(t):
if t.__bases__ == (tuple,) and hasattr(t, '_fields'):
return NamedtupleNode
NamedtupleNode = NodeType('namedtuple',
lambda xs: (tuple(xs), type(xs)),
lambda t, xs: t(*xs))