Merge pull request #14956 from NeilGirdhar:correct_annotation

PiperOrigin-RevId: 516296704
This commit is contained in:
jax authors 2023-03-13 13:03:49 -07:00
commit 8e30ea081f

View File

@ -544,8 +544,8 @@ _register_keypaths(
def register_pytree_with_keys( def register_pytree_with_keys(
nodetype: Type[T], nodetype: Type[T],
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]], flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T]): unflatten_func: Callable[[_AuxData, Iterable[Any]], T]):
"""Extends the set of types that are considered internal nodes in pytrees. """Extends the set of types that are considered internal nodes in pytrees.
This is a more powerful alternative to ``register_pytree_node`` that allows This is a more powerful alternative to ``register_pytree_node`` that allows
@ -553,9 +553,9 @@ def register_pytree_with_keys(
Args: Args:
nodetype: a Python type to treat as an internal pytree node. nodetype: a Python type to treat as an internal pytree node.
flatten_func: a function to be used during flattening, taking a value of flatten_with_keys: a function to be used during flattening, taking a value
type ``nodetype`` and returning a pair, with (1) an iterable for tuples of of type ``nodetype`` and returning a pair, with (1) an iterable for tuples
each key path and its child, and (2) some hashable auxiliary data to be of each key path and its child, and (2) some hashable auxiliary data to be
stored in the treedef and to be passed to the ``unflatten_func``. stored in the treedef and to be passed to the ``unflatten_func``.
unflatten_func: a function taking two arguments: the auxiliary data that was unflatten_func: a function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treedef, and the returned by ``flatten_func`` and stored in the treedef, and the