Merge pull request #14956 from NeilGirdhar:correct_annotation

PiperOrigin-RevId: 516296704
This commit is contained in:
jax authors 2023-03-13 13:03:49 -07:00
commit 8e30ea081f

View File

@ -544,8 +544,8 @@ _register_keypaths(
def register_pytree_with_keys(
nodetype: Type[T],
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T]):
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]],
unflatten_func: Callable[[_AuxData, Iterable[Any]], T]):
"""Extends the set of types that are considered internal nodes in pytrees.
This is a more powerful alternative to ``register_pytree_node`` that allows
@ -553,9 +553,9 @@ def register_pytree_with_keys(
Args:
nodetype: a Python type to treat as an internal pytree node.
flatten_func: a function to be used during flattening, taking a value of
type ``nodetype`` and returning a pair, with (1) an iterable for tuples of
each key path and its child, and (2) some hashable auxiliary data to be
flatten_with_keys: a function to be used during flattening, taking a value
of type ``nodetype`` and returning a pair, with (1) an iterable for tuples
of each key path and its child, and (2) some hashable auxiliary data to be
stored in the treedef and to be passed to the ``unflatten_func``.
unflatten_func: a function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treedef, and the