mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
C++ tree with path API
* Make the tree_util.tree_flatten_with_path and tree_map_with_path APIs C++-based to speed up pytree flattening. * Move all the key classes down to the C++ level while keeping the APIs unchanged. * Known small caveats: the key classes are no longer Python dataclasses, and pattern matching on them might make pytype unhappy. * defaultdict and OrderedDict are now registered via the keypath API. PiperOrigin-RevId: 701613257
This commit is contained in:
parent
db4b3f2922
commit
a1dfdc1d61
@ -22,10 +22,11 @@ import functools
|
||||
from functools import partial
|
||||
import operator as op
|
||||
import textwrap
|
||||
from typing import Any, NamedTuple, TypeVar, Union, overload
|
||||
from typing import Any, NamedTuple, TypeVar, overload
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip, set_module
|
||||
from jax._src.util import unzip2
|
||||
|
||||
@ -209,12 +210,21 @@ def all_leaves(iterable: Iterable[Any],
|
||||
|
||||
_Children = TypeVar("_Children", bound=Iterable[Any])
|
||||
_AuxData = TypeVar("_AuxData", bound=Hashable)
|
||||
KeyEntry = TypeVar("KeyEntry", bound=Any)
|
||||
KeyLeafPair = tuple[KeyEntry, Any]
|
||||
KeyLeafPairs = Iterable[KeyLeafPair]
|
||||
KeyPath = tuple[KeyEntry, ...]
|
||||
|
||||
|
||||
@export
|
||||
def register_pytree_node(nodetype: type[T],
|
||||
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, _Children], T]) -> None:
|
||||
def register_pytree_node(
|
||||
nodetype: type[T],
|
||||
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, _Children], T],
|
||||
flatten_with_keys_func: (
|
||||
Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None
|
||||
) = None,
|
||||
) -> None:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
See :ref:`example usage <pytrees>`.
|
||||
@ -279,9 +289,20 @@ def register_pytree_node(nodetype: type[T],
|
||||
>>> jax.jit(f)(m)
|
||||
Array([1., 2., 3., 4., 5.], dtype=float32)
|
||||
"""
|
||||
default_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
if xla_extension_version >= 299:
|
||||
default_registry.register_node( # type: ignore[call-arg]
|
||||
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
|
||||
)
|
||||
none_leaf_registry.register_node( # type: ignore[call-arg]
|
||||
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
|
||||
)
|
||||
dispatch_registry.register_node( # type: ignore[call-arg]
|
||||
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
|
||||
)
|
||||
else:
|
||||
default_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
||||
|
||||
|
||||
@ -452,21 +473,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool
|
||||
return all(tree_leaves(tree, is_leaf=is_leaf))
|
||||
|
||||
|
||||
register_pytree_node(
|
||||
collections.OrderedDict,
|
||||
lambda x: (tuple(x.values()), tuple(x.keys())),
|
||||
lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))
|
||||
|
||||
def _flatten_defaultdict(d):
|
||||
keys = tuple(sorted(d))
|
||||
return tuple(d[k] for k in keys), (d.default_factory, keys)
|
||||
|
||||
register_pytree_node(
|
||||
collections.defaultdict,
|
||||
_flatten_defaultdict,
|
||||
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))
|
||||
|
||||
|
||||
class _HashableCallableShim:
|
||||
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
|
||||
|
||||
@ -578,11 +584,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
|
||||
|
||||
|
||||
# flatten_one_level is not exported.
|
||||
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
"""Flatten the given pytree node by one level.
|
||||
|
||||
Args:
|
||||
pytree: A valid pytree node, either built-in or registered via
|
||||
tree: A valid pytree node, either built-in or registered via
|
||||
:func:`register_pytree_node` or related functions.
|
||||
|
||||
Returns:
|
||||
@ -601,9 +607,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
>>> meta
|
||||
('a', 'b')
|
||||
"""
|
||||
out = default_registry.flatten_one_level(pytree)
|
||||
out = default_registry.flatten_one_level(tree)
|
||||
if out is None:
|
||||
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
|
||||
raise ValueError(f"can't tree-flatten type: {type(tree)}")
|
||||
else:
|
||||
return out
|
||||
|
||||
@ -739,10 +745,12 @@ class FlattenedIndexKey():
|
||||
def __str__(self):
|
||||
return f'[<flat index {self.key}>]'
|
||||
|
||||
BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
|
||||
|
||||
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
|
||||
KeyPath = tuple[KeyEntry, ...]
|
||||
if xla_extension_version >= 299:
|
||||
SequenceKey = pytree.SequenceKey # type: ignore
|
||||
DictKey = pytree.DictKey # type: ignore
|
||||
GetAttrKey = pytree.GetAttrKey # type: ignore
|
||||
FlattenedIndexKey = pytree.FlattenedIndexKey # type: ignore
|
||||
|
||||
|
||||
@export
|
||||
@ -764,6 +772,7 @@ def keystr(keys: KeyPath):
|
||||
return ''.join(map(str, keys))
|
||||
|
||||
|
||||
# TODO(ivyzheng): remove this after _child_keys() also moved to C++.
|
||||
class _RegistryWithKeypathsEntry(NamedTuple):
|
||||
flatten_with_keys: Callable[..., Any]
|
||||
unflatten_func: Callable[..., Any]
|
||||
@ -780,7 +789,6 @@ def _register_keypaths(
|
||||
flatten_with_keys, _registry[ty].from_iter
|
||||
)
|
||||
|
||||
|
||||
_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {}
|
||||
|
||||
_register_keypaths(
|
||||
@ -803,13 +811,9 @@ _register_keypaths(
|
||||
@export
|
||||
def register_pytree_with_keys(
|
||||
nodetype: type[T],
|
||||
flatten_with_keys: Callable[
|
||||
[T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]
|
||||
],
|
||||
flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
|
||||
flatten_func: None | (
|
||||
Callable[[T], tuple[Iterable[Any], _AuxData]]
|
||||
) = None,
|
||||
flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None,
|
||||
):
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
@ -870,7 +874,9 @@ def register_pytree_with_keys(
|
||||
return [c for _, c in key_children], treedef
|
||||
flatten_func = flatten_func_impl
|
||||
|
||||
register_pytree_node(nodetype, flatten_func, unflatten_func)
|
||||
register_pytree_node(
|
||||
nodetype, flatten_func, unflatten_func, flatten_with_keys
|
||||
)
|
||||
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
|
||||
flatten_with_keys, unflatten_func
|
||||
)
|
||||
@ -1092,6 +1098,40 @@ def register_dataclass(
|
||||
return nodetype
|
||||
|
||||
|
||||
# Registration of collections.OrderedDict and collections.defaultdict as
# pytree nodes. The unflatten helpers are shared by both branches below; only
# the flatten side differs depending on xla_extension_version.


def _unflatten_ordereddict(keys, values):
  """Rebuilds an ``OrderedDict`` by pairing the stored ``keys`` with ``values``."""
  return collections.OrderedDict(safe_zip(keys, values))


def _unflatten_defaultdict(aux_data, values):
  """Rebuilds a ``defaultdict`` from ``(default_factory, keys)`` and ``values``."""
  default_factory, keys = aux_data
  return collections.defaultdict(default_factory, safe_zip(keys, values))


if xla_extension_version >= 299:

  def _flatten_ordereddict_with_keys(d):
    """Flattens an ``OrderedDict`` into ``(DictKey, value)`` pairs.

    Children are emitted in the dict's own iteration order; the tuple of keys
    is returned as the auxiliary data.
    """
    # Iterate items() once instead of looking up d[k] for every key.
    return tuple((DictKey(k), v) for k, v in d.items()), tuple(d.keys())

  def _flatten_defaultdict_with_keys(d):
    """Flattens a ``defaultdict`` into ``(DictKey, value)`` pairs.

    Children are emitted in sorted key order; the auxiliary data is the pair
    ``(default_factory, sorted_keys)``.
    """
    keys = tuple(sorted(d))
    return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys)

  register_pytree_with_keys(
      collections.OrderedDict,
      _flatten_ordereddict_with_keys,
      _unflatten_ordereddict,
  )
  register_pytree_with_keys(
      collections.defaultdict,
      _flatten_defaultdict_with_keys,
      _unflatten_defaultdict,
  )
else:

  def _flatten_ordereddict(d):
    """Flattens an ``OrderedDict`` into ``(values, keys)`` tuples."""
    return tuple(d.values()), tuple(d.keys())

  def _flatten_defaultdict(d):
    """Flattens a ``defaultdict`` in sorted key order.

    Returns the values as children and ``(default_factory, sorted_keys)`` as
    the auxiliary data.
    """
    keys = tuple(sorted(d))
    return tuple(d[k] for k in keys), (d.default_factory, keys)

  register_pytree_node(
      collections.OrderedDict,
      _flatten_ordereddict,
      _unflatten_ordereddict,
  )
  register_pytree_node(
      collections.defaultdict,
      _flatten_defaultdict,
      _unflatten_defaultdict,
  )
|
||||
|
||||
|
||||
@export
|
||||
def register_static(cls: type[H]) -> type[H]:
|
||||
"""Registers `cls` as a pytree with no leaves.
|
||||
@ -1144,6 +1184,8 @@ def tree_flatten_with_path(
|
||||
which contains a leaf and its key path. The second element is a treedef
|
||||
representing the structure of the flattened tree.
|
||||
"""
|
||||
if xla_extension_version >= 299:
|
||||
return default_registry.flatten_with_path(tree, is_leaf)
|
||||
_, tree_def = tree_flatten(tree, is_leaf)
|
||||
return _generate_key_paths(tree, is_leaf), tree_def
|
||||
|
||||
@ -1164,13 +1206,15 @@ def tree_leaves_with_path(
|
||||
- :func:`jax.tree_util.tree_leaves`
|
||||
- :func:`jax.tree_util.tree_flatten_with_path`
|
||||
"""
|
||||
return _generate_key_paths(tree, is_leaf)
|
||||
return tree_flatten_with_path(tree, is_leaf)[0]
|
||||
|
||||
|
||||
# generate_key_paths is not exported.
def generate_key_paths(
    tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> list[tuple[KeyPath, Any]]:
  """Returns the ``(key path, leaf)`` pairs of ``tree``.

  Args:
    tree: the pytree to traverse.
    is_leaf: optional predicate forwarded unchanged to the underlying
      traversal.

  Returns:
    A list of ``(key_path, leaf)`` tuples.
  """
  if xla_extension_version < 299:
    # Older extension: walk the tree with the Python generator.
    return list(_generate_key_paths_((), tree, is_leaf))
  # Newer extension: delegate to tree_leaves_with_path.
  return tree_leaves_with_path(tree, is_leaf)


_generate_key_paths = generate_key_paths  # alias for backward compat
|
||||
|
||||
|
@ -34,17 +34,53 @@ class PackageStructureTest(jtu.JaxTestCase):
|
||||
_mod("jax.errors", exclude=["JaxRuntimeError"]),
|
||||
_mod(
|
||||
"jax.numpy",
|
||||
exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating",
|
||||
"dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo",
|
||||
"flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim",
|
||||
"number", "object_", "printoptions", "save", "savez", "set_printoptions",
|
||||
"shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"]
|
||||
exclude=[
|
||||
"array_repr",
|
||||
"array_str",
|
||||
"can_cast",
|
||||
"character",
|
||||
"complexfloating",
|
||||
"dtype",
|
||||
"iinfo",
|
||||
"index_exp",
|
||||
"inexact",
|
||||
"integer",
|
||||
"iterable",
|
||||
"finfo",
|
||||
"flexible",
|
||||
"floating",
|
||||
"generic",
|
||||
"get_printoptions",
|
||||
"ndarray",
|
||||
"ndim",
|
||||
"number",
|
||||
"object_",
|
||||
"printoptions",
|
||||
"save",
|
||||
"savez",
|
||||
"set_printoptions",
|
||||
"shape",
|
||||
"signedinteger",
|
||||
"size",
|
||||
"s_",
|
||||
"unsignedinteger",
|
||||
"ComplexWarning",
|
||||
],
|
||||
),
|
||||
_mod("jax.numpy.linalg"),
|
||||
_mod("jax.nn.initializers"),
|
||||
_mod(
|
||||
"jax.tree_util",
|
||||
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
|
||||
exclude=[
|
||||
"PyTreeDef",
|
||||
"default_registry",
|
||||
"KeyEntry",
|
||||
"KeyPath",
|
||||
"DictKey",
|
||||
"GetAttrKey",
|
||||
"SequenceKey",
|
||||
"FlattenedIndexKey",
|
||||
],
|
||||
),
|
||||
])
|
||||
def test_exported_names_match_module(self, module_name, include, exclude):
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
from collections.abc import Hashable
|
||||
import dataclasses
|
||||
import functools
|
||||
import pickle
|
||||
@ -24,14 +25,20 @@ import jax
|
||||
from jax import flatten_util
|
||||
from jax import tree_util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.tree_util import flatten_one_level, prefix_errors
|
||||
import jax.numpy as jnp
|
||||
|
||||
# Easier to read.
|
||||
SequenceKey = tree_util.SequenceKey
|
||||
DictKey = tree_util.DictKey
|
||||
GetAttrKey = tree_util.GetAttrKey
|
||||
FlattenedIndexKey = tree_util.FlattenedIndexKey
|
||||
|
||||
|
||||
def _dummy_func(*args, **kwargs):
|
||||
return
|
||||
|
||||
|
||||
ATuple = collections.namedtuple("ATuple", ("foo", "bar"))
|
||||
|
||||
class ANamedTupleSubclass(ATuple):
|
||||
@ -758,6 +765,78 @@ class TreeTest(jtu.JaxTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def testTreeFlattenWithPathBuiltin(self):
  """Flattening a tuple/dict tree with paths yields SequenceKey/DictKey paths.

  Also checks that the returned treedef equals the one from tree_flatten.
  """
  tree = (1, {"a": 2, "b": 3})
  leaves_with_paths, treedef = tree_util.tree_flatten_with_path(tree)
  expected_leaves = [
      ((SequenceKey(0),), 1),
      ((SequenceKey(1), DictKey("a")), 2),
      ((SequenceKey(1), DictKey("b")), 3),
  ]
  self.assertEqual(leaves_with_paths, expected_leaves)
  _, expected_treedef = tree_util.tree_flatten(tree)
  self.assertEqual(treedef, expected_treedef)
|
||||
|
||||
def testTreeFlattenWithPathCustom(self):
  """Flattening with paths through custom node types nested in a list.

  The expected result has no entry for the ``None`` stored under "bar" nor
  for the ``z`` field of the AnObject2 instance.
  """
  inner = {"foo": SpecialWithKeys(x=2, y=3), "bar": None}
  tree = [AnObject2(x=12, y=inner, z="constantdef"), 5]
  leaves_with_paths, _ = tree_util.tree_flatten_with_path(tree)
  expected_leaves = [
      ((SequenceKey(0), "x"), 12),
      ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("x")), 2),
      ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("y")), 3),
      ((SequenceKey(1),), 5),
  ]
  self.assertEqual(leaves_with_paths, expected_leaves)
|
||||
|
||||
def testFlattenWithPathDefaultDict(self):
  """A defaultdict flattens with DictKey paths in sorted key order.

  Also checks that unflattening restores the sorted key order and that the
  treedef matches the one produced by tree_flatten.
  """
  if xla_extension_version < 299:
    self.skipTest("Skipping for Python-based with path APIs.")
  nested = {"b": 2, "a": 1}
  tree = collections.defaultdict(int, {"b": 2, "a": 1, "c": nested})
  leaves_with_paths, treedef = tree_util.tree_flatten_with_path(tree)
  expected_leaves = [
      ((DictKey("a"),), 1),
      ((DictKey("b"),), 2),
      ((DictKey("c"), DictKey("a")), 1),
      ((DictKey("c"), DictKey("b")), 2),
  ]
  self.assertEqual(leaves_with_paths, expected_leaves)
  leaf_values = [leaf for _, leaf in leaves_with_paths]
  rebuilt = tree_util.tree_unflatten(treedef, leaf_values)
  self.assertEqual(list(rebuilt.keys()), ["a", "b", "c"])
  _, plain_treedef = tree_util.tree_flatten(tree)
  self.assertEqual(treedef, plain_treedef)
|
||||
|
||||
def testFlattenWithPathOrderedDict(self):
  """An OrderedDict flattens with DictKey paths in its insertion order.

  The nested plain dict is expected in sorted key order. Also checks that
  unflattening restores the insertion order and that the treedef matches the
  one produced by tree_flatten.
  """
  if xla_extension_version < 299:
    self.skipTest("Skipping for Python-based with path APIs.")
  nested = {"b": 2, "a": 1}
  tree = collections.OrderedDict({"b": 2, "a": 1, "c": nested})
  leaves_with_paths, treedef = tree_util.tree_flatten_with_path(tree)
  expected_leaves = [
      ((DictKey("b"),), 2),
      ((DictKey("a"),), 1),
      ((DictKey("c"), DictKey("a")), 1),
      ((DictKey("c"), DictKey("b")), 2),
  ]
  self.assertEqual(leaves_with_paths, expected_leaves)
  leaf_values = [leaf for _, leaf in leaves_with_paths]
  rebuilt = tree_util.tree_unflatten(treedef, leaf_values)
  self.assertEqual(list(rebuilt.keys()), ["b", "a", "c"])
  _, plain_treedef = tree_util.tree_flatten(tree)
  self.assertEqual(treedef, plain_treedef)
|
||||
|
||||
def testFlattenOneLevel(self):
|
||||
EmptyTuple = collections.namedtuple("EmptyTuple", ())
|
||||
tree1 = {'a': 1,
|
||||
@ -838,6 +917,90 @@ class TreeTest(jtu.JaxTestCase):
|
||||
tree_util.tree_flatten(t)
|
||||
|
||||
|
||||
class TreeKeyTest(absltest.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
if xla_extension_version < 299:
|
||||
self.skipTest("Skipping for Python-based with path APIs.")
|
||||
|
||||
def assert_equal_and_hash_equal(a, b):
|
||||
self.assertEqual(a, b)
|
||||
self.assertEqual(hash(a), hash(b))
|
||||
|
||||
key = SequenceKey(idx=1)
|
||||
self.assertEqual(str(key), "[1]")
|
||||
self.assertEqual(key.idx, 1)
|
||||
assert_equal_and_hash_equal(key, SequenceKey(1))
|
||||
|
||||
class DictKeyEntry(Hashable):
|
||||
|
||||
def __init__(self, s: str):
|
||||
self.s = s
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.s)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.s == other.s
|
||||
|
||||
key = DictKey(key="foo")
|
||||
self.assertEqual(str(key), "['foo']")
|
||||
self.assertEqual(key.key, "foo")
|
||||
assert_equal_and_hash_equal(key, DictKey("foo"))
|
||||
assert_equal_and_hash_equal(
|
||||
DictKey(DictKeyEntry("foo")), DictKey(DictKeyEntry("foo"))
|
||||
)
|
||||
|
||||
key = GetAttrKey(name="bar")
|
||||
self.assertEqual(str(key), ".bar")
|
||||
self.assertEqual(key.name, "bar")
|
||||
assert_equal_and_hash_equal(key, GetAttrKey("bar"))
|
||||
|
||||
key = FlattenedIndexKey(1)
|
||||
self.assertEqual(str(key), "[<flat index 1>]")
|
||||
self.assertEqual(key.key, 1)
|
||||
assert_equal_and_hash_equal(key, FlattenedIndexKey(1))
|
||||
self.assertNotEqual(hash(key), hash(SequenceKey(1)))
|
||||
|
||||
def testPatternMatching(self):
|
||||
keys = [
|
||||
SequenceKey(1),
|
||||
DictKey("foo"),
|
||||
GetAttrKey("bar"),
|
||||
FlattenedIndexKey(1),
|
||||
]
|
||||
for key in keys:
|
||||
match key:
|
||||
case jax.tree_util.SequenceKey(idx=idx):
|
||||
self.assertEqual(idx, 1)
|
||||
case jax.tree_util.DictKey(key=key):
|
||||
self.assertEqual(key, "foo")
|
||||
case jax.tree_util.GetAttrKey(name=name):
|
||||
self.assertEqual(name, "bar")
|
||||
case jax.tree_util.FlattenedIndexKey(key=idx_key):
|
||||
self.assertEqual(idx_key, 1)
|
||||
case _:
|
||||
raise ValueError(f"key not matched: {key}")
|
||||
match [
|
||||
DictKey("foo"),
|
||||
]:
|
||||
case [DictKey("foo"), *_]:
|
||||
pass
|
||||
case _:
|
||||
raise ValueError(f"keys are not matched: {keys}")
|
||||
|
||||
def testPickle(self):
|
||||
keys = [
|
||||
SequenceKey(1),
|
||||
DictKey("foo"),
|
||||
GetAttrKey("bar"),
|
||||
FlattenedIndexKey(1),
|
||||
]
|
||||
for key in keys:
|
||||
unpickled = pickle.loads(pickle.dumps(key))
|
||||
self.assertEqual(key, unpickled)
|
||||
|
||||
|
||||
class StaticTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
|
Loading…
x
Reference in New Issue
Block a user