rocm_jax/jax/_src/tree_util.py
Yash Katariya a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00

571 lines
21 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import difflib
import functools
from functools import partial
import operator as op
from typing import (Any, Callable, Dict, Hashable, Iterable, List, NamedTuple,
Optional, Tuple, Type, TypeVar, overload)
import textwrap
import warnings
from jax._src.lib import pytree
from jax._src.util import safe_zip, unzip2
from jax._src import traceback_util
# Register this file with traceback_util's exclusion list.
# NOTE(review): presumably this hides frames from this module in filtered
# tracebacks -- confirm against traceback_util.
traceback_util.register_exclusion(__file__)
# Generic type variable used in annotations throughout this module.
T = TypeVar("T")
# Type variable bound to class objects (see register_pytree_node_class).
U = TypeVar("U", bound=Type[Any])
# A pytree leaf may be any Python object.
Leaf = Any
# Alias for the treedef type exposed by the ``pytree`` extension module.
PyTreeDef = pytree.PyTreeDef
def tree_flatten(tree: Any,
                 is_leaf: Optional[Callable[[Any], bool]] = None
                 ) -> Tuple[List[Leaf], PyTreeDef]:
  """Flattens a pytree into its leaves and its structure.

  Leaves are emitted in a deterministic order that corresponds to a
  left-to-right, depth-first traversal of the tree.

  Args:
    tree: the pytree to flatten.
    is_leaf: optional predicate invoked at every flattening step. Returning
      true stops the traversal there and treats the whole subtree as a single
      leaf; returning false lets the traversal descend into the object.

  Returns:
    A ``(leaves, treedef)`` pair: the list of leaf values and a treedef that
    records the structure of the flattened tree.
  """
  leaves, treedef = pytree.flatten(tree, is_leaf)
  return leaves, treedef
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
  """Rebuilds a pytree from a treedef and its leaves.

  This is the inverse of :func:`tree_flatten`.

  Args:
    treedef: the treedef describing the structure to rebuild.
    leaves: an iterable of leaves; it must match the leaves of ``treedef``.

  Returns:
    The reconstructed pytree, with ``leaves`` placed into the structure that
    ``treedef`` describes.
  """
  rebuilt = treedef.unflatten(leaves)
  return rebuilt
def tree_leaves(tree: Any,
                is_leaf: Optional[Callable[[Any], bool]] = None
                ) -> List[Leaf]:
  """Returns the list of leaves of a pytree.

  Args:
    tree: the pytree to flatten.
    is_leaf: optional predicate forwarded to ``pytree.flatten``.

  Returns:
    The first element of the pair produced by ``pytree.flatten``.
  """
  leaves, _ = pytree.flatten(tree, is_leaf)
  return leaves
def tree_structure(tree: Any,
                   is_leaf: Optional[Callable[[Any], bool]] = None
                   ) -> PyTreeDef:
  """Returns the treedef of a pytree.

  Args:
    tree: the pytree to inspect.
    is_leaf: optional predicate forwarded to ``pytree.flatten``.

  Returns:
    The second element of the pair produced by ``pytree.flatten``.
  """
  _, treedef = pytree.flatten(tree, is_leaf)
  return treedef
def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
  """Builds a tuple treedef whose children are the given treedefs.

  Args:
    treedefs: iterable of child treedefs; it is materialized into a list
      before being handed to ``pytree.tuple``.
  """
  children = list(treedefs)
  return pytree.tuple(children)
def treedef_children(treedef: PyTreeDef) -> List[PyTreeDef]:
  """Returns the child treedefs reported by ``treedef.children()``."""
  children = treedef.children()
  return children
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
  """Returns True when ``treedef`` consists of exactly one node."""
  node_count = treedef.num_nodes
  return node_count == 1
def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool:
  """Returns True when ``treedef`` has exactly one node and exactly one leaf.

  Stricter than :func:`treedef_is_leaf`, which only checks the node count.
  """
  if treedef.num_nodes != 1:
    return False
  return treedef.num_leaves == 1
def all_leaves(iterable: Iterable[Any],
               is_leaf: Optional[Callable[[Any], bool]] = None) -> bool:
  """Tests whether every element of the given iterable is a leaf.

  >>> import jax
  >>> tree = {"a": [1, 2, 3]}
  >>> assert all_leaves(jax.tree_util.tree_leaves(tree))
  >>> assert not all_leaves([tree])

  This is useful in advanced cases: for example, a library that allows
  arbitrary map operations on a flat iterable of leaves may want to check that
  the result is still a flat iterable of leaves.

  Args:
    iterable: iterable of candidate leaves.
    is_leaf: optional predicate. When given, the elements are collected into a
      list and that list is compared with ``tree_leaves(list, is_leaf)``;
      otherwise the check is delegated to ``pytree.all_leaves``.

  Returns:
    A boolean indicating whether all elements in the input are leaves.
  """
  if is_leaf is not None:
    candidates = list(iterable)
    return candidates == tree_leaves(candidates, is_leaf)
  return pytree.all_leaves(iterable)
# Type of the iterable of children that a node's flatten function returns.
_Children = TypeVar("_Children", bound=Iterable[Any])
# Type of the auxiliary data that a node's flatten function returns; bound to
# Hashable.
_AuxData = TypeVar("_AuxData", bound=Hashable)
def register_pytree_node(nodetype: Type[T],
                         flatten_func: Callable[[T], Tuple[_Children, _AuxData]],
                         unflatten_func: Callable[[_AuxData, _Children], T]):
  """Extends the set of types that are considered internal nodes in pytrees.

  See :ref:`example usage <pytrees>`.

  Args:
    nodetype: a Python type to treat as an internal pytree node.
    flatten_func: called during flattening with a value of type ``nodetype``.
      It returns a pair of (1) an iterable of children to be flattened
      recursively and (2) hashable auxiliary data that is stored in the
      treedef and later passed to ``unflatten_func``.
    unflatten_func: called with the auxiliary data produced by
      ``flatten_func`` and the unflattened children; it should return an
      instance of ``nodetype``.
  """
  # Register with the extension module first so that a failure there leaves the
  # Python-side registry untouched.
  pytree.register_node(nodetype, flatten_func, unflatten_func)
  entry = _RegistryEntry(to_iter=flatten_func, from_iter=unflatten_func)
  _registry[nodetype] = entry
def register_pytree_node_class(cls: U) -> U:
  """Extends the set of types that are considered internal nodes in pytrees.

  A thin, class-oriented wrapper around ``register_pytree_node``. The class
  must provide a ``tree_flatten`` method and a ``tree_unflatten`` attribute
  (a classmethod in the example below)::

    @register_pytree_node_class
    class Special:
      def __init__(self, x, y):
        self.x = x
        self.y = y

      def tree_flatten(self):
        return ((self.x, self.y), None)

      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

  Returns:
    ``cls`` itself, so the function can be used as a class decorator.
  """
  flatten = op.methodcaller('tree_flatten')
  unflatten = cls.tree_unflatten
  register_pytree_node(cls, flatten, unflatten)
  return cls
def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
             is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
  """Maps a multi-input function over pytree args to produce a new pytree.

  Args:
    f: function taking ``1 + len(rest)`` arguments, applied at the
      corresponding leaves of the pytrees.
    tree: a pytree to be mapped over; each of its leaves supplies the first
      positional argument to ``f``.
    rest: a tuple of pytrees, each of which has the same structure as ``tree``
      or has ``tree`` as a prefix.
    is_leaf: optional predicate called at each flattening step. Its boolean
      result decides whether flattening descends into the current object or
      stops there and treats the whole subtree as a leaf.

  Returns:
    A new pytree with the same structure as ``tree`` whose value at each leaf
    is ``f(x, *xs)``, where ``x`` is the corresponding leaf of ``tree`` and
    ``xs`` is the tuple of values at the corresponding nodes in ``rest``.

  Examples:

    >>> import jax.tree_util
    >>> jax.tree_util.tree_map(lambda x: x + 1, {"x": 7, "y": 42})
    {'x': 8, 'y': 43}

    If multiple inputs are passed, the structure of the tree is taken from the
    first input; subsequent inputs need only have ``tree`` as a prefix:

    >>> jax.tree_util.tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]
  """
  first_leaves, treedef = tree_flatten(tree, is_leaf)
  # One list per input tree; position i of each list feeds the i-th call of f.
  leaf_columns = [first_leaves]
  leaf_columns.extend(treedef.flatten_up_to(other) for other in rest)
  return treedef.unflatten(f(*args) for args in zip(*leaf_columns))
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
  """Builds a pytree by delegating to ``treedef.from_iterable_tree(xs)``.

  NOTE(review): ``xs`` is presumably a tree of nested iterables mirroring the
  structure of ``treedef`` -- confirm against the PyTreeDef implementation.
  """
  built = treedef.from_iterable_tree(xs)
  return built
def tree_transpose(outer_treedef: PyTreeDef,
                   inner_treedef: PyTreeDef,
                   pytree_to_transpose: Any) -> Any:
  """Transforms a tree with structure (outer, inner) into one with structure
  (inner, outer).

  Args:
    outer_treedef: treedef of the outer structure of ``pytree_to_transpose``.
    inner_treedef: treedef of the structure nested at each outer leaf.
    pytree_to_transpose: the pytree to transpose.

  Returns:
    A pytree shaped like ``inner_treedef`` whose leaves are pytrees shaped
    like ``outer_treedef``.

  Raises:
    TypeError: if the number of leaves of ``pytree_to_transpose`` differs from
      the product of the leaf counts of the two treedefs.
  """
  flat, treedef = tree_flatten(pytree_to_transpose)
  inner_size = inner_treedef.num_leaves
  outer_size = outer_treedef.num_leaves
  if treedef.num_leaves != (inner_size * outer_size):
    expected_treedef = outer_treedef.compose(inner_treedef)
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
  # Cut the flat leaf list into ``outer_size`` consecutive rows of
  # ``inner_size`` leaves each, then read the rows column by column.
  rows = [flat[i * inner_size:(i + 1) * inner_size] for i in range(outer_size)]
  subtrees = (tree_unflatten(outer_treedef, column) for column in zip(*rows))
  return tree_unflatten(inner_treedef, subtrees)
# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
_RegistryEntry = collections.namedtuple("_RegistryEntry", ["to_iter", "from_iter"])

# Python-side registry of pytree node types. Each entry pairs ``to_iter``
# (node -> (children, metadata)) with ``from_iter`` ((metadata, children) ->
# node).
_registry = {
    tuple: _RegistryEntry(
        to_iter=lambda xs: (xs, None),
        from_iter=lambda _, xs: tuple(xs)),
    list: _RegistryEntry(
        to_iter=lambda xs: (xs, None),
        from_iter=lambda _, xs: list(xs)),
    dict: _RegistryEntry(
        # Items are sorted, unzipped and the pair reversed.
        # NOTE(review): presumably unzip2 yields (keys, values), making the
        # result (values, keys) -- confirm against jax._src.util.unzip2.
        to_iter=lambda xs: unzip2(sorted(xs.items()))[::-1],
        from_iter=lambda keys, xs: dict(zip(keys, xs))),
    type(None): _RegistryEntry(
        to_iter=lambda z: ((), None),
        from_iter=lambda _, xs: None),
}
def _replace_nones(sentinel, tree):
  """Replaces ``None`` in ``tree`` with ``sentinel``.

  Recurses through node types found in ``_registry`` and through namedtuples;
  any other object is returned unchanged.
  """
  if tree is None:
    return sentinel

  handler = _registry.get(type(tree))
  if handler:
    children, metadata = handler.to_iter(tree)
    replaced = [_replace_nones(sentinel, child) for child in children]
    return handler.from_iter(metadata, replaced)

  # handle namedtuple as a special case, based on heuristic
  if isinstance(tree, tuple) and hasattr(tree, '_fields'):
    replaced = [_replace_nones(sentinel, field) for field in tree]
    return type(tree)(*replaced)

  return tree
# Sentinel that distinguishes "no initializer given" from an explicit value
# such as ``None``.
no_initializer = object()


@overload
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any) -> T:
  ...


@overload
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: T) -> T:
  ...


def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: Any = no_initializer) -> T:
  """Reduces the leaves of ``tree`` with ``functools.reduce``.

  Args:
    function: binary function combining the running result with the next
      leaf.
    tree: the pytree whose leaves (as given by ``tree_leaves``) are reduced.
    initializer: optional starting value; when omitted, ``functools.reduce``
      is called without one.

  Returns:
    The result of the reduction.
  """
  leaves = tree_leaves(tree)
  if initializer is no_initializer:
    return functools.reduce(function, leaves)
  return functools.reduce(function, leaves, initializer)
def tree_all(tree: Any) -> bool:
  """Returns ``all()`` applied to the leaves of ``tree``."""
  leaves = tree_leaves(tree)
  return all(leaves)
# Register OrderedDict as a pytree node: the children are the values and the
# auxiliary data is the tuple of keys, both taken in the dict's own iteration
# order.
register_pytree_node(
collections.OrderedDict,
lambda x: (tuple(x.values()), tuple(x.keys())),
lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))
# Register defaultdict as a pytree node: the children are the values and the
# auxiliary data is the pair (default_factory, keys), so that unflattening can
# rebuild a defaultdict with the same factory.
register_pytree_node(
collections.defaultdict,
lambda x: (tuple(x.values()), (x.default_factory, tuple(x.keys()))),
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) # type: ignore[index,call-overload]
class _HashableCallableShim:
  """Object that delegates __call__, __hash__, and __eq__ to another object."""

  def __init__(self, fun):
    # The wrapped callable that every operation is forwarded to.
    self.fun = fun

  def __call__(self, *args, **kw):
    return self.fun(*args, **kw)

  def __hash__(self):
    return hash(self.fun)

  def __eq__(self, other):
    # Two shims compare by their wrapped callables; anything else is compared
    # against the wrapped callable directly.
    target = other.fun if isinstance(other, _HashableCallableShim) else other
    return self.fun == target
class Partial(functools.partial):
  """A version of functools.partial that works in pytrees.

  Use it for partial function evaluation in a way that is compatible with JAX's
  transformations, e.g., ``Partial(func, *args, **kwargs)``.

  (You need to explicitly opt-in to this behavior because we didn't want to give
  functools.partial different semantics than normal function closures.)

  For example, here is a basic usage of ``Partial`` in a manner similar to
  ``functools.partial``:

  >>> import jax.numpy as jnp
  >>> add_one = Partial(jnp.add, 1)
  >>> add_one(2)
  Array(3, dtype=int32, weak_type=True)

  Pytree compatibility means that the resulting partial function can be passed
  as an argument within transformed JAX functions, which is not possible with a
  standard ``functools.partial`` function:

  >>> from jax import jit
  >>> @jit
  ... def call_func(f, *args):
  ...   return f(*args)
  ...
  >>> call_func(add_one, 2)
  Array(3, dtype=int32, weak_type=True)

  Passing zero arguments to ``Partial`` effectively wraps the original function,
  making it a valid argument in JAX transformed functions:

  >>> call_func(Partial(jnp.add), 1, 2)
  Array(3, dtype=int32, weak_type=True)

  Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in a
  ``TypeError``.

  Note that if the result of ``Partial`` is used in the context where the
  value is traced, it results in all bound arguments being traced when passed
  to the partially-evaluated function:

  >>> print_zero = Partial(print, 0)
  >>> print_zero()
  0
  >>> call_func(print_zero)
  Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
  """

  def __new__(klass, func, *args, **kw):
    if not isinstance(func, functools.partial):
      return super().__new__(klass, func, *args, **kw)

    # In Python 3.10+, if func is itself a functools.partial instance,
    # functools.partial.__new__ would merge the arguments of this Partial
    # instance with the arguments of the func. We box func in a class that does
    # not (yet) have a `func` attribute to defeat this optimization, since we
    # care exactly which arguments are considered part of the pytree.
    inner_partial = func
    shim = _HashableCallableShim(inner_partial)
    instance = super().__new__(klass, shim, *args, **kw)
    # Only after construction do we expose the partial-like attributes on the
    # shim.
    shim.func = inner_partial.func
    shim.args = inner_partial.args
    shim.keywords = inner_partial.keywords
    return instance
# Register Partial as a pytree node: the bound positional and keyword
# arguments are the children, and the wrapped callable is the auxiliary data.
register_pytree_node(
Partial,
lambda partial_: ((partial_.args, partial_.keywords), partial_.func),
lambda func, xs: Partial(func, *xs[0], **xs[1]), # type: ignore[index]
)
def broadcast_prefix(prefix_tree: Any, full_tree: Any,
                     is_leaf: Optional[Callable[[Any], bool]] = None
                     ) -> List[Any]:
  """Repeats each leaf of ``prefix_tree`` once per leaf of the matching subtree.

  Args:
    prefix_tree: a pytree that should be a tree prefix of ``full_tree``.
    full_tree: the full pytree.
    is_leaf: optional predicate forwarded to ``tree_map``.

  Returns:
    A flat list in which every leaf ``x`` of ``prefix_tree`` appears as many
    times as the corresponding subtree of ``full_tree`` has leaves.
  """
  # If prefix_tree is not a tree prefix of full_tree, this code can raise a
  # ValueError; use prefix_errors to find disagreements and raise more precise
  # error messages.
  result = []

  def add_leaves(x, subtree):
    leaf_count = tree_structure(subtree).num_leaves
    result.extend([x] * leaf_count)

  tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
  return result
def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
  """Flattens only the root node of ``pytree``.

  Returns:
    A ``(children, metadata)`` pair. For types found in ``_registry`` both
    come from the registered ``to_iter`` function; for namedtuples the
    children are the fields and the metadata is ``None``.

  Raises:
    ValueError: if ``pytree`` is neither a registered node type nor a
      namedtuple.
  """
  handler = _registry.get(type(pytree))
  if handler:
    children, meta = handler.to_iter(pytree)
    return list(children), meta
  # namedtuples are recognized heuristically: a tuple with a ``_fields``
  # attribute.
  if isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
    return list(pytree), None
  raise ValueError(f"can't tree-flatten type: {type(pytree)}")
def prefix_errors(prefix_tree: Any, full_tree: Any,
                  is_leaf: Optional[Callable[[Any], bool]] = None,
                  ) -> List[Callable[[str], ValueError]]:
  """Collects the prefix mismatches between ``prefix_tree`` and ``full_tree``.

  Returns:
    A list of callables; each takes a name string and returns a ``ValueError``
    describing one disagreement, as produced by ``_prefix_error`` starting
    from an empty key path.
  """
  root_path = KeyPath(())
  return list(_prefix_error(root_path, prefix_tree, full_tree, is_leaf))
class KeyPathEntry(NamedTuple):
  """Base class for one step of a key path into a pytree.

  Subclasses must override :meth:`pprint` to render the step.
  """
  # The key identifying a child of a pytree node; its meaning is defined by
  # the subclass.
  key: Any

  def pprint(self) -> str:
    """Returns a printable rendering of this entry; must be overridden.

    Raises:
      NotImplementedError: always, on the base class. An explicit raise is
        used rather than ``assert False`` because asserts are stripped under
        ``python -O``, which would make this method silently return ``None``.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must override pprint()")
class KeyPath(NamedTuple):
  """A sequence of ``KeyPathEntry`` steps locating a subtree in a pytree."""
  # The entries of the path, in order from the root.
  keys: Tuple[KeyPathEntry, ...]

  def __add__(self, other):
    """Returns a new path extended by one entry.

    Raises:
      TypeError: if ``other`` is not a ``KeyPathEntry``.
    """
    if not isinstance(other, KeyPathEntry):
      raise TypeError(type(other))
    extended = self.keys + (other,)
    return KeyPath(extended)

  def pprint(self, root: str = ' tree root') -> str:
    """Renders the path; an empty path is rendered as ``root``."""
    if self.keys:
      return ''.join([entry.pprint() for entry in self.keys])
    return root
class GetitemKeyPathEntry(KeyPathEntry):
  """Key path step rendered in subscript form, e.g. ``['a']`` or ``[0]``."""

  def pprint(self) -> str:
    return '[' + repr(self.key) + ']'
class AttributeKeyPathEntry(KeyPathEntry):
  """Key path step rendered in attribute-access form, e.g. ``.name``."""

  def pprint(self) -> str:
    return '.{}'.format(self.key)
class FlattenedKeyPathEntry(KeyPathEntry):  # fallback
  """Fallback key path step that identifies a child by its flat index."""

  def pprint(self) -> str:
    return '[<flat index {}>]'.format(self.key)
def _child_keys(pytree: Any) -> List[KeyPathEntry]:
  """Returns one key path entry per child of the root node of ``pytree``.

  Handlers in ``_keypath_registry`` take precedence, then namedtuples (keyed
  by field name); any other node type falls back to flat indices.
  ``pytree`` must not be a strict leaf.
  """
  assert not treedef_is_strict_leaf(tree_structure(pytree))
  handler = _keypath_registry.get(type(pytree))
  if handler:
    return handler(pytree)
  # handle namedtuple as a special case, based on heuristic
  if isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
    return [AttributeKeyPathEntry(field) for field in pytree._fields]
  num_children = len(treedef_children(tree_structure(pytree)))
  return list(map(FlattenedKeyPathEntry, range(num_children)))
# Maps a pytree node type to a handler that returns the key path entries of a
# node's children.
_keypath_registry: Dict[Type, Callable[[Any], List[KeyPathEntry]]] = {}


def register_keypaths(
    ty: Type, handler: Callable[[Any], List[KeyPathEntry]]) -> None:
  """Registers ``handler`` as the key path handler for node type ``ty``.

  Args:
    ty: the node type to register; an existing registration is overwritten.
    handler: function taking a node and returning a list of key path entries.
  """
  _keypath_registry[ty] = handler
# Key path handlers for the built-in containers: tuples and lists are keyed by
# position, dicts by their keys taken in sorted order.
register_keypaths(tuple,
lambda tup: [GetitemKeyPathEntry(i) for i in range(len(tup))])
register_keypaths(list,
lambda lst: [GetitemKeyPathEntry(i) for i in range(len(lst))])
register_keypaths(dict,
lambda dct: [GetitemKeyPathEntry(k) for k in sorted(dct)])
def _generate_key_paths(tree: Any) -> List[Tuple[KeyPath, Any]]:
  """Returns the ``(key_path, leaf)`` pairs of ``tree`` as a list.

  The traversal starts from an empty key path.
  """
  root_path = KeyPath(())
  return list(_generate_key_paths_(root_path, tree))
def _generate_key_paths_(key_path: KeyPath, tree: Any
                         ) -> Iterable[Tuple[KeyPath, Any]]:
  """Yields ``(key_path, leaf)`` pairs for the subtree ``tree``.

  Args:
    key_path: the path from the root to ``tree``.
    tree: the subtree to traverse.
  """
  if treedef_is_strict_leaf(tree_structure(tree)):
    yield key_path, tree
    return
  child_keys = _child_keys(tree)
  tree_children, _ = flatten_one_level(tree)
  for entry, child in zip(child_keys, tree_children):
    yield from _generate_key_paths_(key_path + entry, child)
def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
                  is_leaf: Optional[Callable[[Any], bool]] = None,
                  ) -> Iterable[Callable[[str], ValueError]]:
  """Yields explanations of why ``prefix_tree`` is not a prefix of ``full_tree``.

  Args:
    key_path: the path from the root to the subtrees being compared.
    prefix_tree: the candidate prefix subtree.
    full_tree: the subtree of the full pytree at the same key path.
    is_leaf: optional predicate deciding what counts as a leaf of
      ``prefix_tree``; it is applied at every level of the recursion.

  Yields:
    Callables taking a name string and returning a ``ValueError`` that
    describes one disagreement. The search stops at the first disagreement
    found in each subtree.
  """
  # A leaf is a valid prefix of any tree:
  if treedef_is_strict_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)):
    return

  # The messages below interpolate ``name`` directly rather than calling
  # ``str.format`` on the assembled text: type reprs, keys and metadata may
  # contain literal braces, which ``str.format`` would try to parse.

  # The subtrees may disagree because their roots are of different types:
  if type(prefix_tree) != type(full_tree):
    yield lambda name: ValueError(
        "pytree structure error: different types at key path\n"
        f" {name}{key_path.pprint()}\n"
        f"At that key path, the prefix pytree {name} has a subtree of type\n"
        f" {type(prefix_tree)}\n"
        "but at the same key path the full pytree has a subtree of different type\n"
        f" {type(full_tree)}.")
    return  # don't look for more errors in this subtree

  # Or they may disagree if their roots have different numbers of children (note
  # that because both prefix_tree and full_tree have the same type at this
  # point, and because prefix_tree is not a leaf, each can be flattened once):
  prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
  full_tree_children, full_tree_meta = flatten_one_level(full_tree)
  prefix_tree_keys = _child_keys(prefix_tree)
  full_tree_keys = _child_keys(full_tree)
  try:
    diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
  except Exception:
    # Best effort only: keys may be unhashable. Unlike a bare ``except``, this
    # lets KeyboardInterrupt and SystemExit propagate.
    diff = None
  if len(prefix_tree_children) != len(full_tree_children):
    yield lambda name: ValueError(
        "pytree structure error: different numbers of pytree children at key path\n"
        f" {name}{key_path.pprint()}\n"
        f"At that key path, the prefix pytree {name} has a subtree of type\n"
        f" {type(prefix_tree)}\n"
        f"with {len(prefix_tree_children)} child keys\n"
        f" {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
        "but at the same key path the full pytree has a subtree of the same "
        f"type but with {len(full_tree_children)} child keys\n"
        f" {' '.join(str(k.key) for k in full_tree_keys)}\n"
        + ("" if diff is None else
           "so the symmetric difference on key sets is\n"
           f" {' '.join(str(k.key) for k in diff)}"))
    return  # don't look for more errors in this subtree

  # Or they may disagree if their roots have different pytree metadata:
  if prefix_tree_meta != full_tree_meta:
    prefix_tree_meta_str = str(prefix_tree_meta)
    full_tree_meta_str = str(full_tree_meta)
    metadata_diff = textwrap.indent(
        '\n'.join(difflib.ndiff(prefix_tree_meta_str.splitlines(),
                                full_tree_meta_str.splitlines())),
        prefix=" ")
    yield lambda name: ValueError(
        "pytree structure error: different pytree metadata at key path\n"
        f" {name}{key_path.pprint()}\n"
        f"At that key path, the prefix pytree {name} has a subtree of type\n"
        f" {type(prefix_tree)}\n"
        "with metadata\n"
        f" {prefix_tree_meta_str}\n"
        "but at the same key path the full pytree has a subtree of the same "
        "type but with metadata\n"
        f" {full_tree_meta_str}\n"
        "so the diff in the metadata at these pytree nodes is\n"
        f"{metadata_diff}")
    return  # don't look for more errors in this subtree

  # If the root types and numbers of children agree, there must be an error
  # in a subtree, so recurse:
  assert prefix_tree_keys == full_tree_keys, \
      ("equal pytree nodes gave differing prefix_tree_keys: "
       f"{prefix_tree_keys} and {full_tree_keys}")
  for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
    # Forward is_leaf so nested subtrees that the caller treats as leaves are
    # not reported as spurious mismatches.
    yield from _prefix_error(key_path + k, t1, t2, is_leaf)
# TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py
def _deprecate(f):
  """Wraps ``f`` so that every call first emits a ``FutureWarning``.

  The warning names ``jax.<f.__name__>`` as deprecated and points to
  ``jax.tree_util.<f.__name__>``; ``stacklevel=2`` attributes it to the caller
  of the wrapper.
  """
  @functools.wraps(f)
  def warn_then_call(*args, **kwargs):
    message = (
        f"jax.{f.__name__} is deprecated, and will be removed in a future release. "
        f"Use jax.tree_util.{f.__name__} instead.")
    warnings.warn(message, category=FutureWarning, stacklevel=2)
    return f(*args, **kwargs)
  return warn_then_call
def __getattr__(name):
  """Module-level attribute hook that serves ``_deprecated_<name>`` aliases.

  Args:
    name: the attribute being looked up on this module.

  Returns:
    For ``_deprecated_<name>``, the module global ``<name>`` wrapped by
    ``_deprecate``.

  Raises:
    AttributeError: if ``name`` lacks the ``_deprecated_`` prefix, or if no
      module global exists for the part after the prefix. Raising
      AttributeError (not KeyError) keeps ``hasattr`` and three-argument
      ``getattr`` working on this module.
  """
  prefix = "_deprecated_"
  if name.startswith(prefix):
    try:
      target = globals()[name[len(prefix):]]
    except KeyError:
      raise AttributeError(
          f"module {__name__} has no attribute {name!r}") from None
    return _deprecate(target)
  raise AttributeError(f"module {__name__} has no attribute {name!r}")