C++ tree with path API

* Make the tree_util.tree_flatten_with_path and tree_map_with_path APIs C++-based to speed up pytree flattening.

* Move all the key classes down to the C++ level, while keeping the APIs unchanged.
  * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.

* Register collections.defaultdict and collections.OrderedDict via the keypath API.

PiperOrigin-RevId: 701613257
This commit is contained in:
Ivy Zheng 2024-11-30 21:26:07 -08:00 committed by jax authors
parent db4b3f2922
commit a1dfdc1d61
3 changed files with 288 additions and 45 deletions

View File

@ -22,10 +22,11 @@ import functools
from functools import partial
import operator as op
import textwrap
from typing import Any, NamedTuple, TypeVar, Union, overload
from typing import Any, NamedTuple, TypeVar, overload
from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip, set_module
from jax._src.util import unzip2
@ -209,12 +210,21 @@ def all_leaves(iterable: Iterable[Any],
_Children = TypeVar("_Children", bound=Iterable[Any])
_AuxData = TypeVar("_AuxData", bound=Hashable)
KeyEntry = TypeVar("KeyEntry", bound=Any)
KeyLeafPair = tuple[KeyEntry, Any]
KeyLeafPairs = Iterable[KeyLeafPair]
KeyPath = tuple[KeyEntry, ...]
@export
def register_pytree_node(nodetype: type[T],
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T]) -> None:
def register_pytree_node(
nodetype: type[T],
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T],
flatten_with_keys_func: (
Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None
) = None,
) -> None:
"""Extends the set of types that are considered internal nodes in pytrees.
See :ref:`example usage <pytrees>`.
@ -279,9 +289,20 @@ def register_pytree_node(nodetype: type[T],
>>> jax.jit(f)(m)
Array([1., 2., 3., 4., 5.], dtype=float32)
"""
default_registry.register_node(nodetype, flatten_func, unflatten_func)
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
if xla_extension_version >= 299:
default_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
)
none_leaf_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
)
dispatch_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
)
else:
default_registry.register_node(nodetype, flatten_func, unflatten_func)
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
@ -452,21 +473,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool
return all(tree_leaves(tree, is_leaf=is_leaf))
register_pytree_node(
collections.OrderedDict,
lambda x: (tuple(x.values()), tuple(x.keys())),
lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))
def _flatten_defaultdict(d):
keys = tuple(sorted(d))
return tuple(d[k] for k in keys), (d.default_factory, keys)
register_pytree_node(
collections.defaultdict,
_flatten_defaultdict,
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))
class _HashableCallableShim:
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
@ -578,11 +584,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
# flatten_one_level is not exported.
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
"""Flatten the given pytree node by one level.
Args:
pytree: A valid pytree node, either built-in or registered via
tree: A valid pytree node, either built-in or registered via
:func:`register_pytree_node` or related functions.
Returns:
@ -601,9 +607,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
>>> meta
('a', 'b')
"""
out = default_registry.flatten_one_level(pytree)
out = default_registry.flatten_one_level(tree)
if out is None:
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
raise ValueError(f"can't tree-flatten type: {type(tree)}")
else:
return out
@ -739,10 +745,12 @@ class FlattenedIndexKey():
def __str__(self):
return f'[<flat index {self.key}>]'
BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
KeyPath = tuple[KeyEntry, ...]
# When the XLA extension is new enough, bind the built-in key-entry names to
# the classes provided by the `pytree` extension module.
if xla_extension_version >= 299:
  SequenceKey = pytree.SequenceKey # type: ignore
  DictKey = pytree.DictKey # type: ignore
  GetAttrKey = pytree.GetAttrKey # type: ignore
  FlattenedIndexKey = pytree.FlattenedIndexKey # type: ignore
@export
@ -764,6 +772,7 @@ def keystr(keys: KeyPath):
return ''.join(map(str, keys))
# TODO(ivyzheng): remove this after _child_keys() also moved to C++.
class _RegistryWithKeypathsEntry(NamedTuple):
flatten_with_keys: Callable[..., Any]
unflatten_func: Callable[..., Any]
@ -780,7 +789,6 @@ def _register_keypaths(
flatten_with_keys, _registry[ty].from_iter
)
_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {}
_register_keypaths(
@ -803,13 +811,9 @@ _register_keypaths(
@export
def register_pytree_with_keys(
nodetype: type[T],
flatten_with_keys: Callable[
[T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]
],
flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]],
unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
flatten_func: None | (
Callable[[T], tuple[Iterable[Any], _AuxData]]
) = None,
flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None,
):
"""Extends the set of types that are considered internal nodes in pytrees.
@ -870,7 +874,9 @@ def register_pytree_with_keys(
return [c for _, c in key_children], treedef
flatten_func = flatten_func_impl
register_pytree_node(nodetype, flatten_func, unflatten_func)
register_pytree_node(
nodetype, flatten_func, unflatten_func, flatten_with_keys
)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
@ -1092,6 +1098,40 @@ def register_dataclass(
return nodetype
# Register the dict subclasses from `collections` as pytree nodes. With an old
# XLA extension only the plain flatten/unflatten pair is registered; otherwise
# the keypath-aware registration is used so each child is paired with a
# `DictKey`.
if xla_extension_version < 299:

  def _flatten_defaultdict(d):
    # Children are emitted in sorted-key order; the default factory and the
    # sorted keys travel as aux data so the dict can be rebuilt.
    ordered_keys = tuple(sorted(d))
    children = tuple(d[k] for k in ordered_keys)
    return children, (d.default_factory, ordered_keys)

  register_pytree_node(
      collections.OrderedDict,
      lambda od: (tuple(od.values()), tuple(od.keys())),
      lambda ks, vs: collections.OrderedDict(safe_zip(ks, vs)),
  )
  register_pytree_node(
      collections.defaultdict,
      _flatten_defaultdict,
      lambda aux, vs: collections.defaultdict(aux[0], safe_zip(aux[1], vs)),
  )
else:

  def _flatten_defaultdict_with_keys(d):
    # Same layout as the plain flatten, but every child is paired with the
    # `DictKey` it was stored under.
    ordered_keys = tuple(sorted(d))
    keyed_children = tuple((DictKey(k), d[k]) for k in ordered_keys)
    return keyed_children, (d.default_factory, ordered_keys)

  register_pytree_with_keys(
      collections.OrderedDict,
      lambda od: (
          tuple((DictKey(k), od[k]) for k in od.keys()),
          tuple(od.keys()),
      ),
      lambda ks, vs: collections.OrderedDict(safe_zip(ks, vs)),
  )
  register_pytree_with_keys(
      collections.defaultdict,
      _flatten_defaultdict_with_keys,
      lambda aux, vs: collections.defaultdict(aux[0], safe_zip(aux[1], vs)),
  )
@export
def register_static(cls: type[H]) -> type[H]:
"""Registers `cls` as a pytree with no leaves.
@ -1144,6 +1184,8 @@ def tree_flatten_with_path(
which contains a leaf and its key path. The second element is a treedef
representing the structure of the flattened tree.
"""
if xla_extension_version >= 299:
return default_registry.flatten_with_path(tree, is_leaf)
_, tree_def = tree_flatten(tree, is_leaf)
return _generate_key_paths(tree, is_leaf), tree_def
@ -1164,13 +1206,15 @@ def tree_leaves_with_path(
- :func:`jax.tree_util.tree_leaves`
- :func:`jax.tree_util.tree_flatten_with_path`
"""
return _generate_key_paths(tree, is_leaf)
return tree_flatten_with_path(tree, is_leaf)[0]
# generate_key_paths is not exported.
def generate_key_paths(
    tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> list[tuple[KeyPath, Any]]:
  """Return the ``(key path, leaf)`` pairs of ``tree`` as a list.

  Args:
    tree: the pytree to walk.
    is_leaf: optional predicate forwarded unchanged to the underlying
      traversal.

  Returns:
    A list of ``(key path, leaf)`` tuples.
  """
  if xla_extension_version < 299:
    # Older extension: fall back to the Python key-path generator, starting
    # from an empty key path.
    return list(_generate_key_paths_((), tree, is_leaf))
  return tree_leaves_with_path(tree, is_leaf)
_generate_key_paths = generate_key_paths # alias for backward compat

View File

@ -34,17 +34,53 @@ class PackageStructureTest(jtu.JaxTestCase):
_mod("jax.errors", exclude=["JaxRuntimeError"]),
_mod(
"jax.numpy",
exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating",
"dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo",
"flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim",
"number", "object_", "printoptions", "save", "savez", "set_printoptions",
"shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"]
exclude=[
"array_repr",
"array_str",
"can_cast",
"character",
"complexfloating",
"dtype",
"iinfo",
"index_exp",
"inexact",
"integer",
"iterable",
"finfo",
"flexible",
"floating",
"generic",
"get_printoptions",
"ndarray",
"ndim",
"number",
"object_",
"printoptions",
"save",
"savez",
"set_printoptions",
"shape",
"signedinteger",
"size",
"s_",
"unsignedinteger",
"ComplexWarning",
],
),
_mod("jax.numpy.linalg"),
_mod("jax.nn.initializers"),
_mod(
"jax.tree_util",
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
exclude=[
"PyTreeDef",
"default_registry",
"KeyEntry",
"KeyPath",
"DictKey",
"GetAttrKey",
"SequenceKey",
"FlattenedIndexKey",
],
),
])
def test_exported_names_match_module(self, module_name, include, exclude):

View File

@ -13,6 +13,7 @@
# limitations under the License.
import collections
from collections.abc import Hashable
import dataclasses
import functools
import pickle
@ -24,14 +25,20 @@ import jax
from jax import flatten_util
from jax import tree_util
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax._src.tree_util import flatten_one_level, prefix_errors
import jax.numpy as jnp
# Short module-level aliases for the `tree_util` key classes, so the expected
# key paths in the tests below are easier to read.
SequenceKey = tree_util.SequenceKey
DictKey = tree_util.DictKey
GetAttrKey = tree_util.GetAttrKey
FlattenedIndexKey = tree_util.FlattenedIndexKey
def _dummy_func(*args, **kwargs):
return
ATuple = collections.namedtuple("ATuple", ("foo", "bar"))
class ANamedTupleSubclass(ATuple):
@ -758,6 +765,78 @@ class TreeTest(jtu.JaxTestCase):
],
)
def testTreeFlattenWithPathBuiltin(self):
  """Path-flattening a tuple/dict tree gives keyed leaves and the usual treedef."""
  tree = (1, {"a": 2, "b": 3})
  keyed_leaves, treedef_with_path = tree_util.tree_flatten_with_path(tree)
  plain_treedef = tree_util.tree_flatten(tree)[1]

  expected_leaves = [
      ((SequenceKey(0),), 1),
      ((SequenceKey(1), DictKey("a")), 2),
      ((SequenceKey(1), DictKey("b")), 3),
  ]
  self.assertEqual(keyed_leaves, expected_leaves)
  # The treedef must be the same one a plain flatten produces.
  self.assertEqual(treedef_with_path, plain_treedef)
def testTreeFlattenWithPathCustom(self):
  """Path-flattening a tree that contains custom node types gives the expected key paths."""
  inner = SpecialWithKeys(x=2, y=3)
  custom_node = AnObject2(
      x=12,
      y={"foo": inner, "bar": None},
      z="constantdef",
  )
  tree = [custom_node, 5]

  keyed_leaves = tree_util.tree_flatten_with_path(tree)[0]

  expected_leaves = [
      ((SequenceKey(0), "x"), 12),
      ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("x")), 2),
      ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("y")), 3),
      ((SequenceKey(1),), 5),
  ]
  self.assertEqual(keyed_leaves, expected_leaves)
def testFlattenWithPathDefaultDict(self):
  """A defaultdict flattens with paths in sorted-key order and round-trips."""
  if xla_extension_version < 299:
    self.skipTest("Skipping for Python-based with path APIs.")

  source = collections.defaultdict(
      int, {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
  )
  keyed_leaves, treedef = tree_util.tree_flatten_with_path(source)

  expected_leaves = [
      ((DictKey("a"),), 1),
      ((DictKey("b"),), 2),
      ((DictKey("c"), DictKey("a")), 1),
      ((DictKey("c"), DictKey("b")), 2),
  ]
  self.assertEqual(keyed_leaves, expected_leaves)

  # Rebuilding from the bare leaves yields keys in sorted order.
  bare_leaves = [leaf for _, leaf in keyed_leaves]
  rebuilt = tree_util.tree_unflatten(treedef, bare_leaves)
  self.assertEqual(list(rebuilt.keys()), ["a", "b", "c"])

  # The with-path treedef must equal the one from a plain flatten.
  plain_treedef = tree_util.tree_flatten(source)[1]
  self.assertEqual(treedef, plain_treedef)
def testFlattenWithPathOrderedDict(self):
  """An OrderedDict flattens with paths in insertion order and round-trips."""
  if xla_extension_version < 299:
    self.skipTest("Skipping for Python-based with path APIs.")

  source = collections.OrderedDict(
      {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
  )
  keyed_leaves, treedef = tree_util.tree_flatten_with_path(source)

  # Top-level entries keep insertion order; the nested plain dict's leaves
  # are expected in sorted-key order.
  expected_leaves = [
      ((DictKey("b"),), 2),
      ((DictKey("a"),), 1),
      ((DictKey("c"), DictKey("a")), 1),
      ((DictKey("c"), DictKey("b")), 2),
  ]
  self.assertEqual(keyed_leaves, expected_leaves)

  bare_leaves = [leaf for _, leaf in keyed_leaves]
  rebuilt = tree_util.tree_unflatten(treedef, bare_leaves)
  self.assertEqual(list(rebuilt.keys()), ["b", "a", "c"])

  # The with-path treedef must equal the one from a plain flatten.
  plain_treedef = tree_util.tree_flatten(source)[1]
  self.assertEqual(treedef, plain_treedef)
def testFlattenOneLevel(self):
EmptyTuple = collections.namedtuple("EmptyTuple", ())
tree1 = {'a': 1,
@ -838,6 +917,90 @@ class TreeTest(jtu.JaxTestCase):
tree_util.tree_flatten(t)
class TreeKeyTest(absltest.TestCase):
def testBasic(self):
if xla_extension_version < 299:
self.skipTest("Skipping for Python-based with path APIs.")
def assert_equal_and_hash_equal(a, b):
self.assertEqual(a, b)
self.assertEqual(hash(a), hash(b))
key = SequenceKey(idx=1)
self.assertEqual(str(key), "[1]")
self.assertEqual(key.idx, 1)
assert_equal_and_hash_equal(key, SequenceKey(1))
class DictKeyEntry(Hashable):
def __init__(self, s: str):
self.s = s
def __hash__(self):
return hash(self.s)
def __eq__(self, other):
return self.s == other.s
key = DictKey(key="foo")
self.assertEqual(str(key), "['foo']")
self.assertEqual(key.key, "foo")
assert_equal_and_hash_equal(key, DictKey("foo"))
assert_equal_and_hash_equal(
DictKey(DictKeyEntry("foo")), DictKey(DictKeyEntry("foo"))
)
key = GetAttrKey(name="bar")
self.assertEqual(str(key), ".bar")
self.assertEqual(key.name, "bar")
assert_equal_and_hash_equal(key, GetAttrKey("bar"))
key = FlattenedIndexKey(1)
self.assertEqual(str(key), "[<flat index 1>]")
self.assertEqual(key.key, 1)
assert_equal_and_hash_equal(key, FlattenedIndexKey(1))
self.assertNotEqual(hash(key), hash(SequenceKey(1)))
def testPatternMatching(self):
keys = [
SequenceKey(1),
DictKey("foo"),
GetAttrKey("bar"),
FlattenedIndexKey(1),
]
for key in keys:
match key:
case jax.tree_util.SequenceKey(idx=idx):
self.assertEqual(idx, 1)
case jax.tree_util.DictKey(key=key):
self.assertEqual(key, "foo")
case jax.tree_util.GetAttrKey(name=name):
self.assertEqual(name, "bar")
case jax.tree_util.FlattenedIndexKey(key=idx_key):
self.assertEqual(idx_key, 1)
case _:
raise ValueError(f"key not matched: {key}")
match [
DictKey("foo"),
]:
case [DictKey("foo"), *_]:
pass
case _:
raise ValueError(f"keys are not matched: {keys}")
def testPickle(self):
keys = [
SequenceKey(1),
DictKey("foo"),
GetAttrKey("bar"),
FlattenedIndexKey(1),
]
for key in keys:
unpickled = pickle.loads(pickle.dumps(key))
self.assertEqual(key, unpickled)
class StaticTest(parameterized.TestCase):
@parameterized.parameters(