2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2019-05-02 08:02:01 -07:00
|
|
|
"""Utilities for working with tree-like container data structures.
|
|
|
|
|
|
|
|
The code here is independent of JAX. The only dependence is on jax.util, which
|
|
|
|
itself has no JAX-specific code.
|
|
|
|
|
|
|
|
This module provides a small set of utility functions for working with tree-like
|
|
|
|
data structures, such as nested tuples, lists, and dicts. We call these
|
|
|
|
structures pytrees. They are trees in that they are defined recursively (any
|
|
|
|
non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and
|
|
|
|
can be operated on recursively (object identity equivalence is not preserved by
|
|
|
|
mapping operations, and the structures cannot contain reference cycles).
|
|
|
|
|
|
|
|
The set of Python types that are considered pytree nodes (i.e. that can be
|
|
|
|
mapped over, rather than treated as leaves) is extensible. There is a single
|
|
|
|
module-level registry of types, and class hierarchy is ignored. By registering a
|
|
|
|
new pytree node type, that type in effect becomes transparent to the utility
|
|
|
|
functions in this file.
|
|
|
|
"""
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
from __future__ import absolute_import
|
2018-11-21 13:27:26 -08:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
import functools
|
2018-11-17 18:03:33 -08:00
|
|
|
from collections import namedtuple
|
|
|
|
import itertools as it
|
2018-11-21 13:20:44 -08:00
|
|
|
from six.moves import reduce
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
from .util import unzip2, concatenate, partial, safe_map
|
|
|
|
|
|
|
|
# Rebind `map` for the rest of this module to `safe_map` from `.util`, so every
# later `map(...)` call here uses that helper instead of the builtin.
# NOTE(review): presumably `safe_map` checks that its inputs have equal length
# and returns a list -- confirm against `.util`.
map = safe_map
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
def tree_map(f, tree):
  """Apply a function to every leaf of a pytree, producing a new pytree.

  Args:
    f: function to be applied at each leaf.
    tree: a pytree to be mapped over.

  Returns:
    A new pytree with the same structure as `tree` in which each leaf `x` of
    `tree` has been replaced by `f(x)`.
  """
  node_type = _get_node_type(tree)
  if not node_type:
    # No node type applies, so `tree` is itself a leaf.
    return f(tree)

  children, node_spec = node_type.to_iterable(tree)
  mapped_children = [tree_map(f, subtree) for subtree in children]
  return node_type.from_iterable(node_spec, mapped_children)
|
|
|
|
|
|
|
|
def tree_multimap(f, tree, *rest):
  """Map a multi-input function over pytree args to produce a new pytree.

  Args:
    f: function that takes `1 + len(rest)` arguments, to be applied at the
      corresponding leaves of the pytrees.
    tree: a pytree to be mapped over, with each leaf providing the first
      positional argument to `f`.
    *rest: a tuple of pytrees, each with the same structure as `tree`.

  Returns:
    A new pytree with the same structure as `tree` but with the value at each
    leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
    in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.

  Raises:
    TypeError: if a tree in `rest` has, at some node, a different node type,
      different auxiliary data, or a different number of children than `tree`.
  """
  node_type = _get_node_type(tree)
  if not node_type:
    # `tree` is a leaf here; the matching entries of `rest` are passed as-is.
    return f(tree, *rest)

  children, aux_data = node_type.to_iterable(tree)
  # Materialize so that the child count can be compared below.
  children = tuple(children)
  all_children = [children]
  for other_tree in rest:
    other_node_type = _get_node_type(other_tree)
    if node_type != other_node_type:
      raise TypeError('Mismatch: {} != {}'.format(other_node_type, node_type))
    other_children, other_aux_data = node_type.to_iterable(other_tree)
    if other_aux_data != aux_data:
      raise TypeError('Mismatch: {} != {}'.format(other_aux_data, aux_data))
    other_children = tuple(other_children)
    # Without this check the zip below would silently drop trailing children
    # when, e.g., two tuples of different lengths are paired up.
    if len(other_children) != len(children):
      raise TypeError('Mismatch: number of children {} != {}'.format(
          len(other_children), len(children)))
    all_children.append(other_children)

  new_children = [tree_multimap(f, *xs) for xs in zip(*all_children)]
  return node_type.from_iterable(aux_data, new_children)
|
|
|
|
|
2019-03-21 19:08:19 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def tree_reduce(f, tree):
  """Fold the leaves of `tree` with the binary function `f`.

  Args:
    f: two-argument function passed to `reduce`.
    tree: a pytree whose flattened leaves are reduced.

  Returns:
    The result of `reduce(f, leaves)` over the flattened leaves of `tree`.
  """
  leaves = tree_flatten(tree)[0]
  return reduce(f, leaves)
|
|
|
|
|
|
|
|
|
|
|
|
def tree_all(tree):
  """Return True if every leaf of `tree` is truthy (True for no leaves)."""
  leaves = tree_flatten(tree)[0]
  return all(leaves)
|
|
|
|
|
|
|
|
|
|
|
|
def process_pytree(process_node, tree):
  """Walk `tree`, combining children with `process_node` and keeping leaves.

  Args:
    process_node: function applied to the processed children of each node.
    tree: the pytree to walk.

  Returns:
    Whatever `walk_pytree` returns when leaves are passed through unchanged.
  """
  def keep_leaf(x):
    return x

  return walk_pytree(process_node, keep_leaf, tree)
|
|
|
|
|
|
|
|
|
|
|
|
def walk_pytree(f_node, f_leaf, tree):
  """Recursively process a pytree and record its structure.

  Args:
    f_node: function applied to the processed children of each node.
    f_leaf: function applied to each leaf value.
    tree: the pytree to walk.

  Returns:
    A pair `(result, treedef)`. For a leaf, `result` is `f_leaf(tree)` and
    `treedef` is the `leaf` sentinel. For a node, `result` is `f_node` applied
    to the children's results and `treedef` is a `PyTreeDef` holding the node
    type, its auxiliary data and the children's tree definitions.
  """
  node_type = _get_node_type(tree)
  if not node_type:
    return f_leaf(tree), leaf

  children, node_spec = node_type.to_iterable(tree)
  walked = [walk_pytree(f_node, f_leaf, subtree) for subtree in children]
  # Split the (result, treedef) pairs into two parallel sequences.
  proc_children, child_specs = unzip2(walked)
  return f_node(proc_children), PyTreeDef(node_type, node_spec, child_specs)
|
|
|
|
|
|
|
|
|
|
|
|
def build_tree(treedef, xs):
  """Rebuild a pytree from `treedef` and a matching nested iterable `xs`.

  Args:
    treedef: the `leaf` sentinel or a tree definition with `children`,
      `node_type` and `node_data` attributes.
    xs: the value to place at a leaf, or an iterable holding one entry per
      child of `treedef`.

  Returns:
    `xs` itself when `treedef` is `leaf`; otherwise the container produced by
    the node type's `from_iterable` from the recursively built children.
  """
  if treedef is leaf:
    return xs

  # We use 'iter' for clearer error messages
  child_defs = iter(treedef.children)
  child_values = iter(xs)
  built_children = map(build_tree, child_defs, child_values)
  return treedef.node_type.from_iterable(treedef.node_data, built_children)
|
|
|
|
|
|
|
|
|
|
|
|
# tree_flatten(tree) is `walk_pytree` with `concatenate` combining the results
# of each node's children and every leaf wrapped in a one-element list. It
# returns the `(flat, treedef)` pair that `walk_pytree` produces.
tree_flatten = partial(walk_pytree, concatenate, lambda x: [x])
|
|
|
|
|
2019-01-03 16:14:30 -08:00
|
|
|
def tree_unflatten(treedef, xs):
  """Rebuild a pytree from a tree definition and an iterable of leaves.

  Args:
    treedef: the tree definition describing the structure to rebuild.
    xs: iterable of leaf values, consumed in order.

  Returns:
    The pytree built by `_tree_unflatten` from `treedef` and the leaves.
  """
  # A single shared iterator lets the recursion consume leaves left to right.
  leaves = iter(xs)
  return _tree_unflatten(leaves, treedef)
|
|
|
|
|
|
|
|
def _tree_unflatten(xs, treedef):
  """Rebuild the pytree described by `treedef`, pulling leaves from `xs`.

  Args:
    xs: an iterator of leaf values, shared across the whole recursion so that
      leaves are consumed left to right.
    treedef: the `leaf` sentinel or a tree definition with `children`,
      `node_type` and `node_data` attributes.

  Returns:
    The next value of `xs` when `treedef` is `leaf`; otherwise the container
    produced by the node type's `from_iterable` from the rebuilt children.

  Raises:
    ValueError: if `xs` runs out of leaves before the tree is complete.
  """
  if treedef is leaf:
    try:
      return next(xs)
    except StopIteration:
      # A StopIteration escaping from here would be taken as normal
      # end-of-iteration by any enclosing `map`, silently truncating the
      # container being rebuilt; report the shortage explicitly instead.
      raise ValueError('Too few leaves for the given tree definition.')

  # The comprehension evaluates children in order, which matters because they
  # all draw from the same iterator.
  children = [_tree_unflatten(xs, child) for child in treedef.children]
  return treedef.node_type.from_iterable(treedef.node_data, children)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-06 11:59:33 -08:00
|
|
|
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
  """Swap the nesting order of a pytree of pytrees.

  Args:
    outer_treedef: tree definition of the outer structure of the input.
    inner_treedef: tree definition found at each outer leaf of the input.
    pytree_to_transpose: a pytree whose structure is `inner_treedef` nested
      inside `outer_treedef`.

  Returns:
    A pytree with structure `inner_treedef` whose leaves are pytrees with
    structure `outer_treedef`.

  Raises:
    TypeError: if the structure of `pytree_to_transpose` is not
      `inner_treedef` nested inside `outer_treedef`.
  """
  flat, treedef = tree_flatten(pytree_to_transpose)
  expected_treedef = _nested_treedef(inner_treedef, outer_treedef)
  if treedef != expected_treedef:
    raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef))

  inner_size = _num_leaves(inner_treedef)
  outer_size = _num_leaves(outer_treedef)

  # Cut the flat leaves into one row of `inner_size` leaves per outer leaf.
  flat = list(flat)
  rows = [flat[row * inner_size:(row + 1) * inner_size]
          for row in range(outer_size)]

  # Each column gathers, for one inner position, the leaves of all outer
  # positions; that column is exactly one outer-structured subtree.
  subtrees = [tree_unflatten(outer_treedef, column) for column in zip(*rows)]
  return tree_unflatten(inner_treedef, subtrees)
|
|
|
|
|
|
|
|
def _num_leaves(treedef):
  """Count the leaves described by `treedef`."""
  if treedef is leaf:
    return 1
  return sum(_num_leaves(child) for child in treedef.children)
|
|
|
|
|
|
|
|
def _nested_treedef(inner, outer):
  """Return `outer` with each of its leaves replaced by `inner`.

  Only used for error checking in `tree_transpose`.
  """
  if outer is leaf:
    return inner
  nested_children = tuple(
      _nested_treedef(inner, child) for child in outer.children)
  return PyTreeDef(outer.node_type, outer.node_data, nested_children)
|
|
|
|
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def tree_structure(tree):
  """Return the tree definition of `tree`, discarding its leaf values."""
  # Node results are irrelevant here, so every node collapses to None.
  return process_pytree(lambda _: None, tree)[1]
|
|
|
|
|
|
|
|
|
|
|
|
class PyTreeDef(object):
  """Structure of a single pytree node and, recursively, of its children.

  Instances are compared and hashed by value, using the node type, the
  auxiliary node data and the children's definitions.
  """

  def __init__(self, node_type, node_data, children):
    """Store the description of one node.

    Args:
      node_type: object describing the container type; `__repr__` reads its
        `name` attribute.
      node_data: auxiliary data of the node; omitted from the repr when None.
      children: sequence of tree definitions, one per child.
    """
    self.node_type = node_type
    self.node_data = node_data
    self.children = children

  def __repr__(self):
    if self.node_data is None:
      data_repr = ""
    else:
      data_repr = "[{}]".format(self.node_data)
    children_repr = ','.join(repr(child) for child in self.children)
    return "PyTree({}{}, [{}])".format(self.node_type.name, data_repr,
                                       children_repr)

  def __hash__(self):
    return hash((self.node_type, self.node_data, tuple(self.children)))

  def __eq__(self, other):
    # Anything that is not a PyTreeDef (the leaf sentinel, None, ...) is left
    # to Python's default handling, which yields False, instead of failing
    # with AttributeError on the attribute accesses below.
    if not isinstance(other, PyTreeDef):
      return NotImplemented
    return (self.node_type == other.node_type and
            self.node_data == other.node_data and
            self.children == other.children)

  def __ne__(self, other):
    return not self == other
|
|
|
|
|
|
|
|
|
|
|
|
class PyLeaf(object):
  """Marker type for a leaf position in a tree definition; prints as '*'."""

  def __repr__(self):
    return '*'
|
|
|
|
|
|
|
|
# Module-level sentinel marking a leaf in tree definitions. The rest of this
# module tests for it by identity (`treedef is leaf`).
leaf = PyLeaf()
|
|
|
|
|
|
|
|
def dict_to_iterable(xs):
  """Split a dict into its values and its keys, both ordered by sorted key.

  Args:
    xs: the dict to split; its keys must be sortable.

  Returns:
    A pair `(values, keys)` of tuples, where `keys` is sorted and `values`
    holds the corresponding values in the same order.
  """
  ordered_keys = tuple(sorted(xs.keys()))
  ordered_values = tuple(xs.get(key) for key in ordered_keys)
  return ordered_values, ordered_keys
|
|
|
|
|
|
|
|
class NodeType(object):
  """Registry entry describing how one container type is taken apart/rebuilt."""

  def __init__(self, name, to_iterable, from_iterable):
    """Store the entry's fields.

    Args:
      name: display name, returned by `repr`.
      to_iterable: callable used to split an instance into its children and
        auxiliary data.
      from_iterable: callable used to rebuild an instance from auxiliary data
        and children.
    """
    self.name = name
    self.to_iterable, self.from_iterable = to_iterable, from_iterable

  def __repr__(self):
    return self.name
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
# Module-level registry mapping a Python type to its NodeType entry. It is
# filled by `register_pytree_node` and looked up by exact type.
node_types = {}
|
|
|
|
|
|
|
|
def register_pytree_node(py_type, to_iterable, from_iterable):
  """Register `py_type` as a pytree node type in the module-level registry.

  Args:
    py_type: the Python type whose instances should be treated as nodes.
    to_iterable: callable taking an instance of `py_type` and returning a pair
      `(children, aux_data)`.
    from_iterable: callable taking `(aux_data, children)` and returning a
      rebuilt instance.

  Raises:
    ValueError: if `py_type` has already been registered.
  """
  # An explicit check rather than `assert`, which is stripped under `python -O`
  # and would let a duplicate registration silently replace the first one.
  if py_type in node_types:
    raise ValueError(
        'Duplicate pytree node registration for {}'.format(py_type))
  node_types[py_type] = NodeType(str(py_type), to_iterable, from_iterable)
|
|
|
|
|
|
|
|
# Built-in containers that are pytree nodes out of the box. tuple, list and
# None carry no auxiliary data; a dict's auxiliary data is its sorted keys.
register_pytree_node(
    tuple,
    lambda xs: (xs, None),
    lambda _, xs: tuple(xs))
register_pytree_node(
    list,
    lambda xs: (tuple(xs), None),
    lambda _, xs: list(xs))
register_pytree_node(
    dict,
    dict_to_iterable,
    lambda keys, xs: dict(zip(keys, xs)))
# None is a node with no children, so it contributes no leaves.
register_pytree_node(
    type(None),
    lambda z: ((), None),
    lambda _, xs: None)
|
2019-05-20 10:08:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
# Namedtuples cannot be served by the plain node_types table: each namedtuple
# is its own type and would need its own entry. Instead, a heuristic check on
# the type decides whether it is a namedtuple type, and if so it is treated as
# a pytree node.
def _get_node_type(maybe_tree):
  """Return the node type entry for `maybe_tree`, or None if it is a leaf."""
  py_type = type(maybe_tree)
  registered = node_types.get(py_type)
  if registered:
    return registered
  return _namedtuple_node(py_type)
|
|
|
|
|
|
|
|
def _namedtuple_node(t):
  """Return `NamedtupleNode` if `t` looks like a namedtuple type, else None."""
  # Heuristic: a tuple subclass that exposes a `_fields` attribute.
  looks_like_namedtuple = issubclass(t, tuple) and hasattr(t, '_fields')
  return NamedtupleNode if looks_like_namedtuple else None
|
|
|
|
|
|
|
|
# Single node type shared by all namedtuple classes. The concrete class is kept
# as the auxiliary data and called with the children to rebuild an instance.
NamedtupleNode = NodeType(
    'namedtuple',
    lambda xs: (tuple(xs), type(xs)),
    lambda t, xs: t(*xs))
|
2019-07-09 11:38:23 -07:00
|
|
|
|
|
|
|
|
|
|
|
class Partial(functools.partial):
  """A version of functools.partial that works in pytrees.

  Use it for partial function evaluation in a way that is compatible with JAX's
  transformations, e.g., ``Partial(func, *args, **kwargs)``.

  (This behavior is opt-in because functools.partial itself should keep the
  same semantics as ordinary function closures.)
  """
|
|
|
|
|
|
|
|
def _partial_to_iterable(partial_):
|
|
|
|
values = partial_.args + tuple(partial_.keywords.values())
|
|
|
|
spec = (partial_.func, len(partial_.args), tuple(partial_.keywords.keys()))
|
|
|
|
return values, spec
|
|
|
|
|
|
|
|
def _iterable_to_partial(spec, values):
  """Rebuild a `Partial` from a spec and its flattened values.

  Args:
    spec: tuple `(func, args_count, keys)` as produced by
      `_partial_to_iterable`.
    values: iterable holding the `args_count` positional arguments followed by
      one value per entry of `keys`, in the same order.

  Returns:
    `Partial(func, *args, **keywords)` with the arguments taken from `values`.
  """
  func, args_count, keys = spec
  # Materialize first so that any iterable is accepted, not only sequences that
  # support slicing.
  values = tuple(values)
  args = values[:args_count]
  keywords = dict(zip(keys, values[args_count:]))
  return Partial(func, *args, **keywords)
|
|
|
|
|
|
|
|
# Make Partial a pytree node: its bound positional and keyword argument values
# become the children, and the spec built by `_partial_to_iterable` is the
# auxiliary data.
register_pytree_node(Partial, _partial_to_iterable, _iterable_to_partial)
|