Add an optional flatten_func argument to custom node registration even when flatten_with_keys is given, for better perf for those in need.

Fixes #14844

PiperOrigin-RevId: 517308676
This commit is contained in:
Ivy Zheng 2023-03-16 21:34:29 -07:00 committed by jax authors
parent d9598215b8
commit 08c83369be
2 changed files with 53 additions and 6 deletions

View File

@ -542,10 +542,17 @@ _register_keypaths(
collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys())
)
def register_pytree_with_keys(
nodetype: Type[T],
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]],
unflatten_func: Callable[[_AuxData, Iterable[Any]], T]):
flatten_with_keys: Callable[
[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]
],
unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
flatten_func: Optional[
Callable[[T], Tuple[Iterable[Any], _AuxData]]
] = None,
):
"""Extends the set of types that are considered internal nodes in pytrees.
This is a more powerful alternative to ``register_pytree_node`` that allows
@ -561,15 +568,24 @@ def register_pytree_with_keys(
returned by ``flatten_func`` and stored in the treedef, and the
unflattened children. The function should return an instance of
``nodetype``.
flatten_func: an optional function similar to ``flatten_with_keys``, but
returns only children and auxiliary data. It must return the children
in the same order as ``flatten_with_keys``, and return the same aux data.
This argument is optional and only needed for faster traversal when
calling functions without keys like ``tree_map`` and ``tree_flatten``.
"""
def flatten_func(tree):
key_children, treedef = flatten_with_keys(tree)
return [c for _, c in key_children], treedef
if not flatten_func:
def flatten_func_impl(tree):
key_children, treedef = flatten_with_keys(tree)
return [c for _, c in key_children], treedef
flatten_func = flatten_func_impl
register_pytree_node(nodetype, flatten_func, unflatten_func)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
def register_pytree_with_keys_class(cls: U) -> U:
"""Extends the set of types that are considered internal nodes in pytrees.
@ -590,8 +606,12 @@ def register_pytree_with_keys_class(cls: U) -> U:
def tree_unflatten(cls, aux_data, children):
return cls(*children)
"""
flatten_func = (
op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None
)
register_pytree_with_keys(
cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten
cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten,
flatten_func
)
return cls

View File

@ -531,6 +531,33 @@ class TreeTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
flatten_one_level(jnp.array((1, 2)))
def testOptionalFlatten(self):
@tree_util.register_pytree_with_keys_class
class FooClass:
def __init__(self, x, y):
self.x = x
self.y = y
def tree_flatten(self):
return ((self.x, self.y), 'treedef')
def tree_flatten_with_keys(self):
return (((tree_util.GetAttrKey('x'), self.x),
(tree_util.GetAttrKey('x'), self.y)), 'treedef')
@classmethod
def tree_unflatten(cls, _, children):
return cls(*children)
tree = FooClass(x=1, y=2)
self.assertEqual(
str(tree_util.tree_flatten(tree)[1]),
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
)
self.assertEqual(
str(tree_util.tree_flatten_with_path(tree)[1]),
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
)
self.assertEqual(tree_util.tree_flatten(tree)[0],
[l for _, l in tree_util.tree_flatten_with_path(tree)[0]])
class RavelUtilTest(jtu.JaxTestCase):