2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2019-05-02 08:02:01 -07:00
|
|
|
"""Utilities for working with tree-like container data structures.
|
|
|
|
|
|
|
|
This module provides a small set of utility functions for working with tree-like
|
|
|
|
data structures, such as nested tuples, lists, and dicts. We call these
|
|
|
|
structures pytrees. They are trees in that they are defined recursively (any
|
|
|
|
non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and
|
|
|
|
can be operated on recursively (object identity equivalence is not preserved by
|
|
|
|
mapping operations, and the structures cannot contain reference cycles).
|
|
|
|
|
|
|
|
The set of Python types that are considered pytree nodes (i.e. that can be
|
|
|
|
mapped over, rather than treated as leaves) is extensible. There is a single
|
|
|
|
module-level registry of types, and class hierarchy is ignored. By registering a
|
|
|
|
new pytree node type, that type in effect becomes transparent to the utility
|
|
|
|
functions in this file.
|
2019-08-23 16:54:59 -07:00
|
|
|
|
|
|
|
The primary purpose of this module is to enable interoperability between
|
2019-08-23 19:32:45 -07:00
|
|
|
user-defined data structures and JAX transformations (e.g. `jit`). This is not
|
|
|
|
meant to be a general purpose tree-like data structure handling library.
|
2019-10-27 10:29:33 +01:00
|
|
|
|
|
|
|
See the `JAX pytrees notebook <https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html>`_
|
|
|
|
for examples.
|
2019-05-02 08:02:01 -07:00
|
|
|
"""
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
import functools
|
2019-10-10 10:19:43 -04:00
|
|
|
import collections
|
2020-03-10 15:01:18 -07:00
|
|
|
import operator as op
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-29 15:06:05 -04:00
|
|
|
from .lib import pytree
|
|
|
|
|
2019-10-31 14:09:12 -07:00
|
|
|
from .util import partial, safe_zip, unzip2
|
2019-07-29 15:06:05 -04:00
|
|
|
|
2019-10-27 10:29:33 +01:00
|
|
|
def tree_flatten(tree):
  """Flattens a pytree into its leaves and a description of its structure.

  Args:
    tree: the pytree to flatten.

  Returns:
    A pair ``(leaves, treedef)``: ``leaves`` is a list of the leaf values of
    ``tree`` and ``treedef`` is a treedef recording the structure of ``tree``.
  """
  flattened = pytree.flatten(tree)
  return flattened
|
|
|
|
|
|
|
|
def tree_unflatten(treedef, leaves):
  """Rebuilds a pytree from a treedef and a sequence of leaves.

  This is the inverse of :func:`tree_flatten`.

  Args:
    treedef: the treedef describing the structure to rebuild.
    leaves: the leaves to place into that structure; they must correspond to
      the leaves of ``treedef``.

  Returns:
    A pytree with the structure described by ``treedef`` whose leaves are the
    given ``leaves``.
  """
  rebuilt = treedef.unflatten(leaves)
  return rebuilt
|
|
|
|
|
|
|
|
def tree_leaves(tree):
  """Returns the leaves of ``tree``, discarding its structure."""
  leaves, _ = pytree.flatten(tree)
  return leaves
|
|
|
|
|
|
|
|
def tree_structure(tree):
  """Returns the treedef describing the structure of ``tree``."""
  _, treedef = pytree.flatten(tree)
  return treedef
|
|
|
|
|
|
|
|
def treedef_tuple(treedefs):
  """Builds a tuple treedef whose children are the given treedefs.

  Args:
    treedefs: an iterable of child treedefs; it is materialized into a list
      before being handed to ``pytree.tuple``.
  """
  children = list(treedefs)
  return pytree.tuple(children)
|
|
|
|
|
|
|
|
def treedef_children(treedef):
  """Returns the child treedefs of ``treedef`` (via ``treedef.children()``)."""
  child_treedefs = treedef.children()
  return child_treedefs
|
|
|
|
|
|
|
|
def treedef_is_leaf(treedef):
  """Returns True if ``treedef`` consists of exactly one node.

  NOTE(review): a childless container node (e.g. the treedef of ``None``)
  presumably also has a single node and would be reported as a leaf here --
  confirm this is intended by callers.
  """
  node_count = treedef.num_nodes
  return node_count == 1
|
|
|
|
|
2020-03-28 13:14:40 +00:00
|
|
|
def all_leaves(iterable):
  """Tests whether every element of the given iterable is a pytree leaf.

  >>> tree = {"a": [1, 2, 3]}
  >>> assert all_leaves(jax.tree_leaves(tree))
  >>> assert not all_leaves([tree])

  This is useful in advanced cases. For example, a library that allows
  arbitrary map operations on a flat list of leaves may want to check that the
  result is still a flat list of leaves.

  Args:
    iterable: an iterable of values to check.

  Returns:
    A boolean indicating whether all elements in the input are leaves.
  """
  result = pytree.all_leaves(iterable)
  return result
|
|
|
|
|
2019-10-27 10:29:33 +01:00
|
|
|
def register_pytree_node(nodetype, flatten_func, unflatten_func):
  """Registers ``nodetype`` as an internal (non-leaf) pytree node type.

  See `example usage <https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html#Pytrees-are-extensible>`_.

  Args:
    nodetype: a Python type to treat as an internal pytree node.
    flatten_func: a function used during flattening. It takes a value of type
      ``nodetype`` and returns a pair of (1) an iterable of the children to be
      flattened recursively and (2) some auxiliary data to be stored in the
      treedef and passed to ``unflatten_func``.
    unflatten_func: a function taking two arguments: the auxiliary data that
      was returned by ``flatten_func`` and stored in the treedef, and the
      unflattened children. It should return an instance of ``nodetype``.
  """
  # Register with the native pytree implementation first. If that raises, the
  # Python-side registry below is left untouched.
  pytree.register_node(nodetype, flatten_func, unflatten_func)
  # Mirror the registration in the Python-side ``_registry``, which
  # ``_replace_nones`` consults when traversing trees.
  entry = _RegistryEntry(flatten_func, unflatten_func)
  _registry[nodetype] = entry
|
2019-07-29 15:06:05 -04:00
|
|
|
|
2020-03-10 15:01:18 -07:00
|
|
|
def register_pytree_node_class(cls):
  """Registers a class as a pytree node type using its own methods.

  This is a thin wrapper around ``register_pytree_node`` that provides a
  class-oriented interface. Instances are flattened by calling their
  ``tree_flatten()`` method, and rebuilt with ``cls.tree_unflatten`` (typically
  a classmethod). ``cls`` itself is returned, so this works as a decorator::

    @register_pytree_node_class
    class Special:
      def __init__(self, x, y):
        self.x = x
        self.y = y

      def tree_flatten(self):
        return ((self.x, self.y), None)

      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)
  """
  flatten = op.methodcaller('tree_flatten')
  unflatten = cls.tree_unflatten
  register_pytree_node(cls, flatten, unflatten)
  return cls
|
|
|
|
|
2019-09-02 07:25:06 -07:00
|
|
|
def tree_map(f, tree):
  """Applies ``f`` to every leaf of ``tree``, producing a new pytree.

  Args:
    f: unary function to be applied at each leaf.
    tree: a pytree to be mapped over.

  Returns:
    A new pytree with the same structure as ``tree`` in which the value at each
    leaf is ``f(x)``, where ``x`` is the value at the corresponding leaf of the
    input ``tree``.
  """
  leaves, treedef = pytree.flatten(tree)
  # Lazily evaluated, so ``f`` runs as ``unflatten`` consumes the leaves.
  mapped = (f(leaf) for leaf in leaves)
  return treedef.unflatten(mapped)
|
|
|
|
|
|
|
|
def tree_multimap(f, tree, *rest):
  """Maps a multi-input function over pytree args to produce a new pytree.

  Args:
    f: function that takes ``1 + len(rest)`` arguments, to be applied at the
      corresponding leaves of the pytrees.
    tree: a pytree to be mapped over, with each leaf providing the first
      positional argument to ``f``.
    *rest: a tuple of pytrees, each of which has the same structure as
      ``tree`` or has ``tree`` as a prefix.

  Returns:
    A new pytree with the same structure as ``tree`` but with the value at each
    leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding
    leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
    ``rest``.
  """
  leaves, treedef = pytree.flatten(tree)
  # One list of leaves per input tree. ``rest`` is flattened only "up to" the
  # structure of ``tree``, since ``tree`` need only be a prefix of each of them.
  # (Named so as not to shadow the module-level ``all_leaves`` function.)
  per_tree_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  return treedef.unflatten(f(*xs) for xs in zip(*per_tree_leaves))
|
|
|
|
|
2019-10-31 14:09:12 -07:00
|
|
|
# TODO(mattjj,phawkins): consider removing this function
def _process_pytree(process_node, tree):
  """Flattens ``tree`` and walks its treedef with ``process_node``.

  Returns a pair of the value produced by
  ``treedef.walk(process_node, None, leaves)`` and the treedef itself.
  NOTE(review): the role of the ``None`` argument to ``walk`` is not visible
  from this file -- confirm against the pytree extension.
  """
  leaves, treedef = pytree.flatten(tree)
  walked = treedef.walk(process_node, None, leaves)
  return walked, treedef
|
|
|
|
|
|
|
|
def build_tree(treedef, xs):
  """Builds a pytree from ``xs`` using ``treedef.from_iterable_tree``.

  NOTE(review): presumably ``xs`` must be a nested iterable mirroring the
  structure of ``treedef`` -- confirm.
  """
  built = treedef.from_iterable_tree(xs)
  return built
|
|
|
|
|
|
|
|
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
  """Transposes a pytree of pytrees, swapping outer and inner structures.

  ``pytree_to_transpose`` must have the structure
  ``outer_treedef.compose(inner_treedef)``. Its leaves are grouped into one row
  of ``inner_treedef.num_leaves`` leaves per outer leaf position. The result is
  an ``inner_treedef``-shaped tree whose j-th leaf position holds an
  ``outer_treedef``-shaped tree built from the j-th leaf of every row.

  Args:
    outer_treedef: treedef of the outer structure.
    inner_treedef: treedef of the inner structure.
    pytree_to_transpose: the pytree to transpose.

  Returns:
    The transposed pytree.

  Raises:
    TypeError: if the structure of ``pytree_to_transpose`` differs from
      ``outer_treedef.compose(inner_treedef)``.
  """
  flat, treedef = tree_flatten(pytree_to_transpose)
  expected_treedef = outer_treedef.compose(inner_treedef)
  if treedef != expected_treedef:
    raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef))

  inner_size = inner_treedef.num_leaves
  outer_size = outer_treedef.num_leaves
  # Row i holds the inner leaves found at the i-th outer leaf position.
  leaf_iter = iter(flat)
  rows = [[next(leaf_iter) for _ in range(inner_size)]
          for __ in range(outer_size)]
  # Build the columns by index rather than with ``zip(*rows)``. When the outer
  # structure has no leaves, ``rows`` is empty and ``zip(*rows)`` would yield
  # no columns at all instead of one (empty) column per inner leaf, making the
  # final ``tree_unflatten(inner_treedef, ...)`` receive too few subtrees.
  columns = [[row[j] for row in rows] for j in range(inner_size)]
  subtrees = [tree_unflatten(outer_treedef, column) for column in columns]
  return tree_unflatten(inner_treedef, subtrees)
|
|
|
|
|
2019-10-31 14:09:12 -07:00
|
|
|
# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
_RegistryEntry = collections.namedtuple("RegistryEntry", ["to_iter", "from_iter"])

# Python-side table of pytree node types. ``to_iter`` splits a value into
# ``(children, metadata)`` and ``from_iter`` rebuilds a value from
# ``(metadata, children)``.
_registry = {}
_registry[tuple] = _RegistryEntry(
    lambda node: (node, None),
    lambda _meta, children: tuple(children))
_registry[list] = _RegistryEntry(
    lambda node: (node, None),
    lambda _meta, children: list(children))
# Dict items are sorted by key; ``unzip2`` presumably splits them into
# (keys, values), which ``[::-1]`` reorders to (children=values, metadata=keys).
_registry[dict] = _RegistryEntry(
    lambda node: unzip2(sorted(node.items()))[::-1],
    lambda keys, children: dict(zip(keys, children)))
# ``None`` has no children and no metadata.
_registry[type(None)] = _RegistryEntry(
    lambda _node: ((), None),
    lambda _meta, _children: None)
|
|
|
|
def _replace_nones(sentinel, tree):
  """Returns ``tree`` with every ``None`` in it replaced by ``sentinel``.

  Types with an entry in ``_registry`` are rebuilt from their recursively
  processed children. Namedtuples, recognized heuristically as tuples with a
  ``_fields`` attribute, are rebuilt positionally. Any other value is
  returned unchanged.
  """
  if tree is None:
    return sentinel

  def recurse(child):
    return _replace_nones(sentinel, child)

  handler = _registry.get(type(tree))
  if handler:
    children, metadata = handler.to_iter(tree)
    return handler.from_iter(metadata, [recurse(child) for child in children])

  if isinstance(tree, tuple) and hasattr(tree, '_fields'):
    # Namedtuples are handled as a special case, based on a heuristic.
    return type(tree)(*map(recurse, tree))

  return tree
|
2019-05-20 10:08:33 -07:00
|
|
|
|
2020-05-05 10:11:10 +02:00
|
|
|
no_initializer = object()


def tree_reduce(function, tree, initializer=no_initializer):
  """Folds the leaves of ``tree`` with ``functools.reduce``.

  Args:
    function: a binary function combining the accumulated value with a leaf.
    tree: the pytree whose leaves (in ``tree_leaves`` order) are reduced.
    initializer: optional starting value. When omitted, ``functools.reduce``
      starts from the first leaf (and raises ``TypeError`` if there are none).

  Returns:
    The reduced value.
  """
  leaves = tree_leaves(tree)
  # Forward ``initializer`` only when the caller actually supplied one.
  start = () if initializer is no_initializer else (initializer,)
  return functools.reduce(function, leaves, *start)
|
2019-05-20 10:08:33 -07:00
|
|
|
|
2019-07-29 15:06:05 -04:00
|
|
|
def tree_all(tree):
  """Returns whether every leaf of ``tree`` is truthy (True if it has none)."""
  leaves = tree_leaves(tree)
  return all(leaves)
|
2019-07-09 11:38:23 -07:00
|
|
|
|
2019-10-10 10:19:43 -04:00
|
|
|
# ``OrderedDict``: the values (in iteration order) are the children, and the
# keys are kept as auxiliary data so the mapping is rebuilt in the same order.
register_pytree_node(
    collections.OrderedDict,
    lambda odict: (list(odict.values()), list(odict.keys())),
    lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))

# ``defaultdict``: like the above, but the auxiliary data also carries the
# ``default_factory`` so the rebuilt mapping keeps its default behavior.
register_pytree_node(
    collections.defaultdict,
    lambda ddict: (tuple(ddict.values()),
                   (ddict.default_factory, tuple(ddict.keys()))),
    lambda aux, values: collections.defaultdict(aux[0],
                                                safe_zip(aux[1], values)))
|
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
|
|
|
|
class Partial(functools.partial):
  """A version of ``functools.partial`` that works in pytrees.

  Use it for partial function application in a way that is compatible with
  JAX's transformations, e.g. ``Partial(func, *args, **kwargs)``.

  When flattened, the bound positional and keyword arguments form the
  children, and the wrapped function is kept as auxiliary data.

  (You need to explicitly opt in to this behavior because we didn't want to
  give ``functools.partial`` different semantics than normal function
  closures.)
  """


register_pytree_node(
    Partial,
    lambda part: ((part.args, part.keywords), part.func),
    lambda func, children: Partial(func, *children[0], **children[1]),
)
|