mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Merge pull request #14956 from NeilGirdhar:correct_annotation
PiperOrigin-RevId: 516296704
This commit is contained in:
commit
8e30ea081f
@ -544,8 +544,8 @@ _register_keypaths(
|
|||||||
|
|
||||||
def register_pytree_with_keys(
|
def register_pytree_with_keys(
|
||||||
nodetype: Type[T],
|
nodetype: Type[T],
|
||||||
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]],
|
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]],
|
||||||
unflatten_func: Callable[[_AuxData, _Children], T]):
|
unflatten_func: Callable[[_AuxData, Iterable[Any]], T]):
|
||||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||||
|
|
||||||
This is a more powerful alternative to ``register_pytree_node`` that allows
|
This is a more powerful alternative to ``register_pytree_node`` that allows
|
||||||
@ -553,9 +553,9 @@ def register_pytree_with_keys(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodetype: a Python type to treat as an internal pytree node.
|
nodetype: a Python type to treat as an internal pytree node.
|
||||||
flatten_func: a function to be used during flattening, taking a value of
|
flatten_with_keys: a function to be used during flattening, taking a value
|
||||||
type ``nodetype`` and returning a pair, with (1) an iterable for tuples of
|
of type ``nodetype`` and returning a pair, with (1) an iterable for tuples
|
||||||
each key path and its child, and (2) some hashable auxiliary data to be
|
of each key path and its child, and (2) some hashable auxiliary data to be
|
||||||
stored in the treedef and to be passed to the ``unflatten_func``.
|
stored in the treedef and to be passed to the ``unflatten_func``.
|
||||||
unflatten_func: a function taking two arguments: the auxiliary data that was
|
unflatten_func: a function taking two arguments: the auxiliary data that was
|
||||||
returned by ``flatten_func`` and stored in the treedef, and the
|
returned by ``flatten_func`` and stored in the treedef, and the
|
||||||
|
Loading…
x
Reference in New Issue
Block a user