mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #14956 from NeilGirdhar:correct_annotation
PiperOrigin-RevId: 516296704
This commit is contained in:
commit
8e30ea081f
@ -544,8 +544,8 @@ _register_keypaths(
|
||||
|
||||
def register_pytree_with_keys(
|
||||
nodetype: Type[T],
|
||||
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, _Children]], _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, _Children], T]):
|
||||
flatten_with_keys: Callable[[T], Tuple[Iterable[Tuple[KeyPath, Any]], _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, Iterable[Any]], T]):
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
This is a more powerful alternative to ``register_pytree_node`` that allows
|
||||
@ -553,9 +553,9 @@ def register_pytree_with_keys(
|
||||
|
||||
Args:
|
||||
nodetype: a Python type to treat as an internal pytree node.
|
||||
flatten_func: a function to be used during flattening, taking a value of
|
||||
type ``nodetype`` and returning a pair, with (1) an iterable for tuples of
|
||||
each key path and its child, and (2) some hashable auxiliary data to be
|
||||
flatten_with_keys: a function to be used during flattening, taking a value
|
||||
of type ``nodetype`` and returning a pair, with (1) an iterable for tuples
|
||||
of each key path and its child, and (2) some hashable auxiliary data to be
|
||||
stored in the treedef and to be passed to the ``unflatten_func``.
|
||||
unflatten_func: a function taking two arguments: the auxiliary data that was
|
||||
returned by ``flatten_func`` and stored in the treedef, and the
|
||||
|
Loading…
x
Reference in New Issue
Block a user