diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 0e8a270df..e33d5c802 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -542,10 +542,17 @@ _register_keypaths( collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys()) ) + def register_pytree_with_keys( nodetype: Type[T], - flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]], - unflatten_func: Callable[[_AuxData, Iterable[Any]], T]): + flatten_with_keys: Callable[ + [T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData] + ], + unflatten_func: Callable[[_AuxData, Iterable[Any]], T], + flatten_func: Optional[ + Callable[[T], Tuple[Iterable[Any], _AuxData]] + ] = None, +): """Extends the set of types that are considered internal nodes in pytrees. This is a more powerful alternative to ``register_pytree_node`` that allows @@ -561,15 +568,24 @@ def register_pytree_with_keys( returned by ``flatten_func`` and stored in the treedef, and the unflattened children. The function should return an instance of ``nodetype``. + flatten_func: an optional function similar to ``flatten_with_keys``, but + returns only children and auxiliary data. It must return the children + in the same order as ``flatten_with_keys``, and return the same aux data. + This argument is optional and only needed for faster traversal when + calling functions without keys like ``tree_map`` and ``tree_flatten``. """ - def flatten_func(tree): - key_children, treedef = flatten_with_keys(tree) - return [c for _, c in key_children], treedef + if not flatten_func: + def flatten_func_impl(tree): + key_children, treedef = flatten_with_keys(tree) + return [c for _, c in key_children], treedef + flatten_func = flatten_func_impl + register_pytree_node(nodetype, flatten_func, unflatten_func) _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( flatten_with_keys, unflatten_func ) + def register_pytree_with_keys_class(cls: U) -> U: """Extends the set of types that are considered internal nodes in pytrees. @@ -590,8 +606,12 @@ def register_pytree_with_keys_class(cls: U) -> U: def tree_unflatten(cls, aux_data, children): return cls(*children) """ + flatten_func = ( + op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None + ) register_pytree_with_keys( - cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten + cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten, + flatten_func ) return cls diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 6614ab302..11f85e243 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -531,6 +531,33 @@ class TreeTest(jtu.JaxTestCase): with self.assertRaisesRegex(ValueError, "can't tree-flatten type"): flatten_one_level(jnp.array((1, 2))) + def testOptionalFlatten(self): + @tree_util.register_pytree_with_keys_class + class FooClass: + def __init__(self, x, y): + self.x = x + self.y = y + def tree_flatten(self): + return ((self.x, self.y), 'treedef') + def tree_flatten_with_keys(self): + return (((tree_util.GetAttrKey('x'), self.x), + (tree_util.GetAttrKey('x'), self.y)), 'treedef') + @classmethod + def tree_unflatten(cls, _, children): + return cls(*children) + + tree = FooClass(x=1, y=2) + self.assertEqual( + str(tree_util.tree_flatten(tree)[1]), + "PyTreeDef(CustomNode(FooClass[treedef], [*, *]))", + ) + self.assertEqual( + str(tree_util.tree_flatten_with_path(tree)[1]), + "PyTreeDef(CustomNode(FooClass[treedef], [*, *]))", + ) + self.assertEqual(tree_util.tree_flatten(tree)[0], + [l for _, l in tree_util.tree_flatten_with_path(tree)[0]]) + class RavelUtilTest(jtu.JaxTestCase):