mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add an optional flatten_func
argument to custom node registration even when flatten_with_keys
is given, for better perf for those in need.
Fixes #14844 PiperOrigin-RevId: 517308676
This commit is contained in:
parent
d9598215b8
commit
08c83369be
@ -542,10 +542,17 @@ _register_keypaths(
|
||||
collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys())
|
||||
)
|
||||
|
||||
|
||||
def register_pytree_with_keys(
|
||||
nodetype: Type[T],
|
||||
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, Iterable[Any]], T]):
|
||||
flatten_with_keys: Callable[
|
||||
[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]
|
||||
],
|
||||
unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
|
||||
flatten_func: Optional[
|
||||
Callable[[T], Tuple[Iterable[Any], _AuxData]]
|
||||
] = None,
|
||||
):
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
This is a more powerful alternative to ``register_pytree_node`` that allows
|
||||
@ -561,15 +568,24 @@ def register_pytree_with_keys(
|
||||
returned by ``flatten_func`` and stored in the treedef, and the
|
||||
unflattened children. The function should return an instance of
|
||||
``nodetype``.
|
||||
flatten_func: an optional function similar to ``flatten_with_keys``, but
|
||||
returns only children and auxiliary data. It must return the children
|
||||
in the same order as ``flatten_with_keys``, and return the same aux data.
|
||||
This argument is optional and only needed for faster traversal when
|
||||
calling functions without keys like ``tree_map`` and ``tree_flatten``.
|
||||
"""
|
||||
def flatten_func(tree):
|
||||
key_children, treedef = flatten_with_keys(tree)
|
||||
return [c for _, c in key_children], treedef
|
||||
if not flatten_func:
|
||||
def flatten_func_impl(tree):
|
||||
key_children, treedef = flatten_with_keys(tree)
|
||||
return [c for _, c in key_children], treedef
|
||||
flatten_func = flatten_func_impl
|
||||
|
||||
register_pytree_node(nodetype, flatten_func, unflatten_func)
|
||||
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
|
||||
flatten_with_keys, unflatten_func
|
||||
)
|
||||
|
||||
|
||||
def register_pytree_with_keys_class(cls: U) -> U:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
@ -590,8 +606,12 @@ def register_pytree_with_keys_class(cls: U) -> U:
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
"""
|
||||
flatten_func = (
|
||||
op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None
|
||||
)
|
||||
register_pytree_with_keys(
|
||||
cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten
|
||||
cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten,
|
||||
flatten_func
|
||||
)
|
||||
return cls
|
||||
|
||||
|
@ -531,6 +531,33 @@ class TreeTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
|
||||
flatten_one_level(jnp.array((1, 2)))
|
||||
|
||||
def testOptionalFlatten(self):
|
||||
@tree_util.register_pytree_with_keys_class
|
||||
class FooClass:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
def tree_flatten(self):
|
||||
return ((self.x, self.y), 'treedef')
|
||||
def tree_flatten_with_keys(self):
|
||||
return (((tree_util.GetAttrKey('x'), self.x),
|
||||
(tree_util.GetAttrKey('x'), self.y)), 'treedef')
|
||||
@classmethod
|
||||
def tree_unflatten(cls, _, children):
|
||||
return cls(*children)
|
||||
|
||||
tree = FooClass(x=1, y=2)
|
||||
self.assertEqual(
|
||||
str(tree_util.tree_flatten(tree)[1]),
|
||||
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(tree_util.tree_flatten_with_path(tree)[1]),
|
||||
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
|
||||
)
|
||||
self.assertEqual(tree_util.tree_flatten(tree)[0],
|
||||
[l for _, l in tree_util.tree_flatten_with_path(tree)[0]])
|
||||
|
||||
|
||||
class RavelUtilTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user