diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index bd2ca2df7..73cff5aa8 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -22,10 +22,11 @@ import functools from functools import partial import operator as op import textwrap -from typing import Any, NamedTuple, TypeVar, Union, overload +from typing import Any, NamedTuple, TypeVar, overload from jax._src import traceback_util from jax._src.lib import pytree +from jax._src.lib import xla_extension_version from jax._src.util import safe_zip, set_module from jax._src.util import unzip2 @@ -209,12 +210,21 @@ def all_leaves(iterable: Iterable[Any], _Children = TypeVar("_Children", bound=Iterable[Any]) _AuxData = TypeVar("_AuxData", bound=Hashable) +KeyEntry = TypeVar("KeyEntry", bound=Any) +KeyLeafPair = tuple[KeyEntry, Any] +KeyLeafPairs = Iterable[KeyLeafPair] +KeyPath = tuple[KeyEntry, ...] @export -def register_pytree_node(nodetype: type[T], - flatten_func: Callable[[T], tuple[_Children, _AuxData]], - unflatten_func: Callable[[_AuxData, _Children], T]) -> None: +def register_pytree_node( + nodetype: type[T], + flatten_func: Callable[[T], tuple[_Children, _AuxData]], + unflatten_func: Callable[[_AuxData, _Children], T], + flatten_with_keys_func: ( + Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None + ) = None, +) -> None: """Extends the set of types that are considered internal nodes in pytrees. See :ref:`example usage `. @@ -279,9 +289,20 @@ def register_pytree_node(nodetype: type[T], >>> jax.jit(f)(m) Array([1., 2., 3., 4., 5.], dtype=float32) """ - default_registry.register_node(nodetype, flatten_func, unflatten_func) - none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) - dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) + if xla_extension_version >= 299: + default_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func + ) + none_leaf_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func + ) + dispatch_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func + ) + else: + default_registry.register_node(nodetype, flatten_func, unflatten_func) + none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) + dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) @@ -452,21 +473,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool return all(tree_leaves(tree, is_leaf=is_leaf)) -register_pytree_node( - collections.OrderedDict, - lambda x: (tuple(x.values()), tuple(x.keys())), - lambda keys, values: collections.OrderedDict(safe_zip(keys, values))) - -def _flatten_defaultdict(d): - keys = tuple(sorted(d)) - return tuple(d[k] for k in keys), (d.default_factory, keys) - -register_pytree_node( - collections.defaultdict, - _flatten_defaultdict, - lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) - - class _HashableCallableShim: """Object that delegates __call__, __hash__, and __eq__ to another object.""" @@ -578,11 +584,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any, # flatten_one_level is not exported. -def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: +def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]: """Flatten the given pytree node by one level. Args: - pytree: A valid pytree node, either built-in or registered via + tree: A valid pytree node, either built-in or registered via :func:`register_pytree_node` or related functions. Returns: @@ -601,9 +607,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: >>> meta ('a', 'b') """ - out = default_registry.flatten_one_level(pytree) + out = default_registry.flatten_one_level(tree) if out is None: - raise ValueError(f"can't tree-flatten type: {type(pytree)}") + raise ValueError(f"can't tree-flatten type: {type(tree)}") else: return out @@ -739,10 +745,12 @@ class FlattenedIndexKey(): def __str__(self): return f'[]' -BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey] -KeyEntry = TypeVar("KeyEntry", bound=Hashable) -KeyPath = tuple[KeyEntry, ...] +if xla_extension_version >= 299: + SequenceKey = pytree.SequenceKey # type: ignore + DictKey = pytree.DictKey # type: ignore + GetAttrKey = pytree.GetAttrKey # type: ignore + FlattenedIndexKey = pytree.FlattenedIndexKey # type: ignore @export @@ -764,6 +772,7 @@ def keystr(keys: KeyPath): return ''.join(map(str, keys)) +# TODO(ivyzheng): remove this after _child_keys() also moved to C++. class _RegistryWithKeypathsEntry(NamedTuple): flatten_with_keys: Callable[..., Any] unflatten_func: Callable[..., Any] @@ -780,7 +789,6 @@ def _register_keypaths( flatten_with_keys, _registry[ty].from_iter ) - _registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {} _register_keypaths( @@ -803,13 +811,9 @@ _register_keypaths( @export def register_pytree_with_keys( nodetype: type[T], - flatten_with_keys: Callable[ - [T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData] - ], + flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]], unflatten_func: Callable[[_AuxData, Iterable[Any]], T], - flatten_func: None | ( - Callable[[T], tuple[Iterable[Any], _AuxData]] - ) = None, + flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None, ): """Extends the set of types that are considered internal nodes in pytrees. @@ -870,7 +874,9 @@ def register_pytree_with_keys( return [c for _, c in key_children], treedef flatten_func = flatten_func_impl - register_pytree_node(nodetype, flatten_func, unflatten_func) + register_pytree_node( + nodetype, flatten_func, unflatten_func, flatten_with_keys + ) _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( flatten_with_keys, unflatten_func ) @@ -1092,6 +1098,40 @@ def register_dataclass( return nodetype +if xla_extension_version >= 299: + register_pytree_with_keys( + collections.OrderedDict, + lambda x: (tuple((DictKey(k), x[k]) for k in x.keys()), tuple(x.keys())), + lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), + ) + + def _flatten_defaultdict_with_keys(d): + keys = tuple(sorted(d)) + return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys) + + register_pytree_with_keys( + collections.defaultdict, + _flatten_defaultdict_with_keys, + lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), + ) +else: + register_pytree_node( + collections.OrderedDict, + lambda x: (tuple(x.values()), tuple(x.keys())), + lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), + ) + + def _flatten_defaultdict(d): + keys = tuple(sorted(d)) + return tuple(d[k] for k in keys), (d.default_factory, keys) + + register_pytree_node( + collections.defaultdict, + _flatten_defaultdict, + lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), + ) + + @export def register_static(cls: type[H]) -> type[H]: """Registers `cls` as a pytree with no leaves. @@ -1144,6 +1184,8 @@ def tree_flatten_with_path( which contains a leaf and its key path. The second element is a treedef representing the structure of the flattened tree. """ + if xla_extension_version >= 299: + return default_registry.flatten_with_path(tree, is_leaf) _, tree_def = tree_flatten(tree, is_leaf) return _generate_key_paths(tree, is_leaf), tree_def @@ -1164,13 +1206,15 @@ def tree_leaves_with_path( - :func:`jax.tree_util.tree_leaves` - :func:`jax.tree_util.tree_flatten_with_path` """ - return _generate_key_paths(tree, is_leaf) + return tree_flatten_with_path(tree, is_leaf)[0] # generate_key_paths is not exported. def generate_key_paths( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tuple[KeyPath, Any]]: + if xla_extension_version >= 299: + return tree_leaves_with_path(tree, is_leaf) return list(_generate_key_paths_((), tree, is_leaf)) _generate_key_paths = generate_key_paths # alias for backward compat diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 25468c4ba..d80c750ae 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -34,17 +34,53 @@ class PackageStructureTest(jtu.JaxTestCase): _mod("jax.errors", exclude=["JaxRuntimeError"]), _mod( "jax.numpy", - exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating", - "dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo", - "flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim", - "number", "object_", "printoptions", "save", "savez", "set_printoptions", - "shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"] + exclude=[ + "array_repr", + "array_str", + "can_cast", + "character", + "complexfloating", + "dtype", + "iinfo", + "index_exp", + "inexact", + "integer", + "iterable", + "finfo", + "flexible", + "floating", + "generic", + "get_printoptions", + "ndarray", + "ndim", + "number", + "object_", + "printoptions", + "save", + "savez", + "set_printoptions", + "shape", + "signedinteger", + "size", + "s_", + "unsignedinteger", + "ComplexWarning", + ], ), _mod("jax.numpy.linalg"), _mod("jax.nn.initializers"), _mod( "jax.tree_util", - exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"], + exclude=[ + "PyTreeDef", + "default_registry", + "KeyEntry", + "KeyPath", + "DictKey", + "GetAttrKey", + "SequenceKey", + "FlattenedIndexKey", + ], ), ]) def test_exported_names_match_module(self, module_name, include, exclude): diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index a3a8bc96e..bd0497a33 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +from collections.abc import Hashable import dataclasses import functools import pickle @@ -24,14 +25,20 @@ import jax from jax import flatten_util from jax import tree_util from jax._src import test_util as jtu +from jax._src.lib import xla_extension_version from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp +# Easier to read. +SequenceKey = tree_util.SequenceKey +DictKey = tree_util.DictKey +GetAttrKey = tree_util.GetAttrKey +FlattenedIndexKey = tree_util.FlattenedIndexKey + def _dummy_func(*args, **kwargs): return - ATuple = collections.namedtuple("ATuple", ("foo", "bar")) class ANamedTupleSubclass(ATuple): @@ -758,6 +765,78 @@ class TreeTest(jtu.JaxTestCase): ], ) + def testTreeFlattenWithPathBuiltin(self): + x = (1, {"a": 2, "b": 3}) + flattened = tree_util.tree_flatten_with_path(x) + _, tdef = tree_util.tree_flatten(x) + self.assertEqual( + flattened[0], + [ + ((SequenceKey(0),), 1), + ((SequenceKey(1), DictKey("a")), 2), + ((SequenceKey(1), DictKey("b")), 3), + ], + ) + self.assertEqual(flattened[1], tdef) + + def testTreeFlattenWithPathCustom(self): + x = [ + AnObject2( + x=12, + y={"foo": SpecialWithKeys(x=2, y=3), "bar": None}, + z="constantdef", + ), + 5, + ] + flattened, _ = tree_util.tree_flatten_with_path(x) + self.assertEqual( + flattened, + [ + ((SequenceKey(0), "x"), 12), + ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("x")), 2), + ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("y")), 3), + ((SequenceKey(1),), 5), + ], + ) + + def testFlattenWithPathDefaultDict(self): + if xla_extension_version < 299: + self.skipTest("Skipping for Python-based with path APIs.") + d = collections.defaultdict(int, {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) + leaves, treedef = tree_util.tree_flatten_with_path(d) + self.assertEqual( + leaves, + [ + ((DictKey("a"),), 1), + ((DictKey("b"),), 2), + ((DictKey("c"), DictKey("a")), 1), + ((DictKey("c"), DictKey("b")), 2), + ], + ) + restored_d = tree_util.tree_unflatten(treedef, [l for _, l in leaves]) + self.assertEqual(list(restored_d.keys()), ["a", "b", "c"]) + _, from_flatten = tree_util.tree_flatten(d) + self.assertEqual(treedef, from_flatten) + + def testFlattenWithPathOrderedDict(self): + if xla_extension_version < 299: + self.skipTest("Skipping for Python-based with path APIs.") + d = collections.OrderedDict({"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) + leaves, treedef = tree_util.tree_flatten_with_path(d) + self.assertEqual( + leaves, + [ + ((DictKey("b"),), 2), + ((DictKey("a"),), 1), + ((DictKey("c"), DictKey("a")), 1), + ((DictKey("c"), DictKey("b")), 2), + ], + ) + restored_d = tree_util.tree_unflatten(treedef, [l for _, l in leaves]) + self.assertEqual(list(restored_d.keys()), ["b", "a", "c"]) + _, from_flatten = tree_util.tree_flatten(d) + self.assertEqual(treedef, from_flatten) + def testFlattenOneLevel(self): EmptyTuple = collections.namedtuple("EmptyTuple", ()) tree1 = {'a': 1, @@ -838,6 +917,90 @@ class TreeTest(jtu.JaxTestCase): tree_util.tree_flatten(t) +class TreeKeyTest(absltest.TestCase): + + def testBasic(self): + if xla_extension_version < 299: + self.skipTest("Skipping for Python-based with path APIs.") + + def assert_equal_and_hash_equal(a, b): + self.assertEqual(a, b) + self.assertEqual(hash(a), hash(b)) + + key = SequenceKey(idx=1) + self.assertEqual(str(key), "[1]") + self.assertEqual(key.idx, 1) + assert_equal_and_hash_equal(key, SequenceKey(1)) + + class DictKeyEntry(Hashable): + + def __init__(self, s: str): + self.s = s + + def __hash__(self): + return hash(self.s) + + def __eq__(self, other): + return self.s == other.s + + key = DictKey(key="foo") + self.assertEqual(str(key), "['foo']") + self.assertEqual(key.key, "foo") + assert_equal_and_hash_equal(key, DictKey("foo")) + assert_equal_and_hash_equal( + DictKey(DictKeyEntry("foo")), DictKey(DictKeyEntry("foo")) + ) + + key = GetAttrKey(name="bar") + self.assertEqual(str(key), ".bar") + self.assertEqual(key.name, "bar") + assert_equal_and_hash_equal(key, GetAttrKey("bar")) + + key = FlattenedIndexKey(1) + self.assertEqual(str(key), "[]") + self.assertEqual(key.key, 1) + assert_equal_and_hash_equal(key, FlattenedIndexKey(1)) + self.assertNotEqual(hash(key), hash(SequenceKey(1))) + + def testPatternMatching(self): + keys = [ + SequenceKey(1), + DictKey("foo"), + GetAttrKey("bar"), + FlattenedIndexKey(1), + ] + for key in keys: + match key: + case jax.tree_util.SequenceKey(idx=idx): + self.assertEqual(idx, 1) + case jax.tree_util.DictKey(key=key): + self.assertEqual(key, "foo") + case jax.tree_util.GetAttrKey(name=name): + self.assertEqual(name, "bar") + case jax.tree_util.FlattenedIndexKey(key=idx_key): + self.assertEqual(idx_key, 1) + case _: + raise ValueError(f"key not matched: {key}") + match [ + DictKey("foo"), + ]: + case [DictKey("foo"), *_]: + pass + case _: + raise ValueError(f"keys are not matched: {keys}") + + def testPickle(self): + keys = [ + SequenceKey(1), + DictKey("foo"), + GetAttrKey("bar"), + FlattenedIndexKey(1), + ] + for key in keys: + unpickled = pickle.loads(pickle.dumps(key)) + self.assertEqual(key, unpickled) + + class StaticTest(parameterized.TestCase): @parameterized.parameters(