mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
make namedtuples transparent (act as pytree nodes)
This commit is contained in:
parent
5cbaf75d28
commit
88f691f896
@ -56,7 +56,7 @@ def tree_map(f, tree):
|
||||
leaf given by `f(x)` where `x` is the value at the corresponding leaf in
|
||||
`tree`.
|
||||
"""
|
||||
node_type = node_types.get(type(tree))
|
||||
node_type = _get_node_type(tree)
|
||||
if node_type:
|
||||
children, node_spec = node_type.to_iterable(tree)
|
||||
new_children = [tree_map(f, child) for child in children]
|
||||
@ -79,12 +79,12 @@ def tree_multimap(f, tree, *rest):
|
||||
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
|
||||
in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
|
||||
"""
|
||||
node_type = node_types.get(type(tree))
|
||||
node_type = _get_node_type(tree)
|
||||
if node_type:
|
||||
children, aux_data = node_type.to_iterable(tree)
|
||||
all_children = [children]
|
||||
for other_tree in rest:
|
||||
other_node_type = node_types.get(type(other_tree))
|
||||
other_node_type = _get_node_type(other_tree)
|
||||
if node_type != other_node_type:
|
||||
raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
|
||||
other_children, other_aux_data = node_type.to_iterable(other_tree)
|
||||
@ -113,7 +113,7 @@ def process_pytree(process_node, tree):
|
||||
|
||||
|
||||
def walk_pytree(f_node, f_leaf, tree):
|
||||
node_type = node_types.get(type(tree))
|
||||
node_type = _get_node_type(tree)
|
||||
if node_type:
|
||||
children, node_spec = node_type.to_iterable(tree)
|
||||
proc_children, child_specs = unzip2([walk_pytree(f_node, f_leaf, child)
|
||||
@ -236,3 +236,20 @@ register_pytree_node(tuple, lambda xs: (xs, None), lambda _, xs: tuple(xs))
|
||||
register_pytree_node(list, lambda xs: (tuple(xs), None), lambda _, xs: list(xs))
|
||||
register_pytree_node(dict, dict_to_iterable, lambda keys, xs: dict(zip(keys, xs)))
|
||||
register_pytree_node(type(None), lambda z: ((), None), lambda _, xs: None)
|
||||
|
||||
|
||||
# To handle namedtuples, we can't just use the standard table of node_types
|
||||
# because every namedtuple creates its own type and thus would require its own
|
||||
# entry in the table. Instead we use a heuristic check on the type itself to
|
||||
# decide whether it's a namedtuple type, and if so treat it as a pytree node.
|
||||
def _get_node_type(maybe_tree):
|
||||
t = type(maybe_tree)
|
||||
return node_types.get(t) or _namedtuple_node(t)
|
||||
|
||||
def _namedtuple_node(t):
|
||||
if t.__bases__ == (tuple,) and hasattr(t, '_fields'):
|
||||
return NamedtupleNode
|
||||
|
||||
NamedtupleNode = NodeType('namedtuple',
|
||||
lambda xs: (tuple(xs), type(xs)),
|
||||
lambda t, xs: t(*xs))
|
||||
|
Loading…
x
Reference in New Issue
Block a user