Correct register_pytree_with_keys annotation

This commit is contained in:
Neil Girdhar 2023-03-13 15:10:47 -04:00
parent 06e54dd8b7
commit ae4b0c5430

View File

@ -544,8 +544,8 @@ _register_keypaths(
def register_pytree_with_keys(
nodetype: Type[T],
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T]):
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]],
unflatten_func: Callable[[_AuxData, Iterable[Any]], T]):
"""Extends the set of types that are considered internal nodes in pytrees.
This is a more powerful alternative to ``register_pytree_node`` that allows
@ -553,9 +553,9 @@ def register_pytree_with_keys(
Args:
nodetype: a Python type to treat as an internal pytree node.
flatten_func: a function to be used during flattening, taking a value of
type ``nodetype`` and returning a pair, with (1) an iterable for tuples of
each key path and its child, and (2) some hashable auxiliary data to be
flatten_with_keys: a function to be used during flattening, taking a value
of type ``nodetype`` and returning a pair, with (1) an iterable for tuples
of each key path and its child, and (2) some hashable auxiliary data to be
stored in the treedef and to be passed to the ``unflatten_func``.
unflatten_func: a function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treedef, and the