Document jax.tree.* directly

This commit is contained in:
Jake VanderPlas 2024-06-12 14:01:27 -07:00
parent 06fe7052bf
commit d82b66f77f
3 changed files with 231 additions and 223 deletions

View File

@ -20,19 +20,27 @@ List of Functions
register_pytree_with_keys
register_pytree_with_keys_class
register_static
tree_all
tree_flatten
tree_flatten_with_path
tree_leaves
tree_leaves_with_path
tree_map
tree_map_with_path
tree_reduce
tree_structure
tree_transpose
tree_unflatten
treedef_children
treedef_is_leaf
treedef_tuple
keystr
Legacy APIs
-----------
These APIs are now accessed via :mod:`jax.tree`.
.. autosummary::
:toctree: _autosummary
tree_all
tree_flatten
tree_leaves
tree_map
tree_reduce
tree_structure
tree_transpose
tree_unflatten

View File

@ -21,21 +21,93 @@ T = TypeVar("T")
def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
"""Alias of :func:`jax.tree_util.tree_all`."""
"""Call all() over the leaves of a tree.
Args:
tree: the pytree to evaluate
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
result: boolean True or False
Examples:
>>> import jax
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
True
>>> jax.tree.all([False, (True, False)])
False
See Also:
- :func:`jax.tree.reduce`
- :func:`jax.tree.leaves`
"""
return tree_util.tree_all(tree, is_leaf=is_leaf)
def flatten(tree: Any,
is_leaf: Callable[[Any], bool] | None = None
) -> tuple[list[tree_util.Leaf], tree_util.PyTreeDef]:
"""Alias of :func:`jax.tree_util.tree_flatten`."""
"""Flattens a pytree.
The flattening order (i.e. the order of elements in the output list)
is deterministic, corresponding to a left-to-right depth-first tree
traversal.
Args:
tree: a pytree to flatten.
is_leaf: an optionally specified function that will be called at each
flattening step. It should return a boolean, with true stopping the
traversal and the whole subtree being treated as a leaf, and false
indicating the flattening should traverse the current object.
Returns:
A pair where the first element is a list of leaf values and the second
element is a treedef representing the structure of the flattened tree.
Example:
>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> vals
[1, 2, 3, 4, 5]
>>> treedef
PyTreeDef([*, (*, *), [*, *]])
See Also:
- :func:`jax.tree.leaves`
- :func:`jax.tree.structure`
- :func:`jax.tree.unflatten`
"""
return tree_util.tree_flatten(tree, is_leaf)
def leaves(tree: Any,
is_leaf: Callable[[Any], bool] | None = None
) -> list[tree_util.Leaf]:
"""Alias of :func:`jax.tree_util.tree_leaves`."""
"""Gets the leaves of a pytree.
Args:
tree: the pytree for which to get the leaves
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
leaves: a list of tree leaves.
Example:
>>> import jax
>>> jax.tree.leaves([1, (2, 3), [4, 5]])
[1, 2, 3, 4, 5]
See Also:
- :func:`jax.tree.flatten`
- :func:`jax.tree.structure`
- :func:`jax.tree.unflatten`
"""
return tree_util.tree_leaves(tree, is_leaf)
@ -43,7 +115,42 @@ def map(f: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[[Any], bool] | None = None) -> Any:
"""Alias of :func:`jax.tree_util.tree_map`."""
"""Maps a multi-input function over pytree args to produce a new pytree.
Args:
f: function that takes ``1 + len(rest)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to ``f``.
rest: a tuple of pytrees, each of which has the same structure as ``tree``
or has ``tree`` as a prefix.
is_leaf: an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each
leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding
leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
``rest``.
Examples:
>>> import jax
>>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42})
{'x': 8, 'y': 43}
If multiple inputs are passed, the structure of the tree is taken from the
first input; subsequent inputs need only have ``tree`` as a prefix:
>>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]
See Also:
- :func:`jax.tree.leaves`
- :func:`jax.tree.reduce`
"""
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
@ -63,24 +170,116 @@ def reduce(function: Callable[[T, Any], T],
tree: Any,
initializer: Any = tree_util.no_initializer,
is_leaf: Callable[[Any], bool] | None = None) -> T:
"""Alias of :func:`jax.tree_util.tree_reduce`."""
"""Call reduce() over the leaves of a tree.
Args:
function: the reduction function
tree: the pytree to reduce over
initializer: the optional initial value
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
result: the reduced value.
Examples:
>>> import jax
>>> import operator
>>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])
21
See Also:
- :func:`jax.tree.leaves`
- :func:`jax.tree.map`
"""
return tree_util.tree_reduce(function, tree, initializer, is_leaf=is_leaf)
def structure(tree: Any,
is_leaf: None | (Callable[[Any], bool]) = None) -> tree_util.PyTreeDef:
"""Alias of :func:`jax.tree_util.tree_structure`."""
"""Gets the treedef for a pytree.
Args:
tree: the pytree for which to get the leaves
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
pytreedef: a PyTreeDef representing the structure of the tree.
Example:
>>> import jax
>>> jax.tree.structure([1, (2, 3), [4, 5]])
PyTreeDef([*, (*, *), [*, *]])
See Also:
- :func:`jax.tree.flatten`
- :func:`jax.tree.leaves`
- :func:`jax.tree.unflatten`
"""
return tree_util.tree_structure(tree, is_leaf)
def transpose(outer_treedef: tree_util.PyTreeDef,
inner_treedef: tree_util.PyTreeDef,
pytree_to_transpose: Any) -> Any:
"""Alias of :func:`jax.tree_util.tree_transpose`."""
"""Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
Args:
outer_treedef: PyTreeDef representing the outer tree.
inner_treedef: PyTreeDef representing the inner tree.
If None, then it will be inferred from outer_treedef and the structure of
pytree_to_transpose.
pytree_to_transpose: the pytree to be transposed.
Returns:
transposed_pytree: the transposed pytree.
Examples:
>>> import jax
>>> tree = [(1, 2, 3), (4, 5, 6)]
>>> inner_structure = jax.tree.structure(('*', '*', '*'))
>>> outer_structure = jax.tree.structure(['*', '*'])
>>> jax.tree.transpose(outer_structure, inner_structure, tree)
([1, 4], [2, 5], [3, 6])
Inferring the inner structure:
>>> jax.tree.transpose(outer_structure, None, tree)
([1, 4], [2, 5], [3, 6])
"""
return tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
def unflatten(treedef: tree_util.PyTreeDef,
leaves: Iterable[tree_util.Leaf]) -> Any:
"""Alias of :func:`jax.tree_util.tree_unflatten`."""
"""Reconstructs a pytree from the treedef and the leaves.
The inverse of :func:`tree_flatten`.
Args:
treedef: the treedef to reconstruct
leaves: the iterable of leaves to use for reconstruction. The iterable must
match the leaves of the treedef.
Returns:
The reconstructed pytree, containing the ``leaves`` placed in the structure
described by ``treedef``.
Example:
>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> newvals = [100, 200, 300, 400, 500]
>>> jax.tree.unflatten(treedef, newvals)
[100, (200, 300), [400, 500]]
See Also:
- :func:`jax.tree.flatten`
- :func:`jax.tree.leaves`
- :func:`jax.tree.structure`
"""
return tree_util.tree_unflatten(treedef, leaves)

View File

@ -74,66 +74,13 @@ dispatch_registry.__name__ = "dispatch_registry"
def tree_flatten(tree: Any,
is_leaf: Callable[[Any], bool] | None = None
) -> tuple[list[Leaf], PyTreeDef]:
"""Flattens a pytree.
The flattening order (i.e. the order of elements in the output list)
is deterministic, corresponding to a left-to-right depth-first tree
traversal.
Args:
tree: a pytree to flatten.
is_leaf: an optionally specified function that will be called at each
flattening step. It should return a boolean, with true stopping the
traversal and the whole subtree being treated as a leaf, and false
indicating the flattening should traverse the current object.
Returns:
A pair where the first element is a list of leaf values and the second
element is a treedef representing the structure of the flattened tree.
Example:
>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> vals
[1, 2, 3, 4, 5]
>>> treedef
PyTreeDef([*, (*, *), [*, *]])
See Also:
- :func:`jax.tree.leaves`
- :func:`jax.tree.structure`
- :func:`jax.tree.unflatten`
"""
"""Alias of :func:`jax.tree.flatten`."""
return default_registry.flatten(tree, is_leaf)
@export
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
"""Reconstructs a pytree from the treedef and the leaves.
The inverse of :func:`tree_flatten`.
Args:
treedef: the treedef to reconstruct
leaves: the iterable of leaves to use for reconstruction. The iterable must
match the leaves of the treedef.
Returns:
The reconstructed pytree, containing the ``leaves`` placed in the structure
described by ``treedef``.
Example:
>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> newvals = [100, 200, 300, 400, 500]
>>> jax.tree.unflatten(treedef, newvals)
[100, (200, 300), [400, 500]]
See Also:
- :func:`jax.tree.flatten`
- :func:`jax.tree.leaves`
- :func:`jax.tree.structure`
"""
"""Alias of :func:`jax.tree.unflatten`."""
return treedef.unflatten(leaves)
@ -141,28 +88,7 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
def tree_leaves(tree: Any,
is_leaf: Callable[[Any], bool] | None = None
) -> list[Leaf]:
"""Gets the leaves of a pytree.
Args:
tree: the pytree for which to get the leaves
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
leaves: a list of tree leaves.
Example:
>>> import jax
>>> jax.tree.leaves([1, (2, 3), [4, 5]])
[1, 2, 3, 4, 5]
See Also:
- :func:`jax.tree.flatten`
- :func:`jax.tree.structure`
- :func:`jax.tree.unflatten`
"""
"""Alias of :func:`jax.tree.leaves`."""
return default_registry.flatten(tree, is_leaf)[0]
@ -170,28 +96,7 @@ def tree_leaves(tree: Any,
def tree_structure(tree: Any,
is_leaf: None | (Callable[[Any],
bool]) = None) -> PyTreeDef:
"""Gets the treedef for a pytree.
Args:
tree: the pytree for which to get the leaves
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
pytreedef: a PyTreeDef representing the structure of the tree.
Example:
>>> import jax
>>> jax.tree.structure([1, (2, 3), [4, 5]])
PyTreeDef([*, (*, *), [*, *]])
See Also:
- :func:`jax.tree.flatten`
- :func:`jax.tree.leaves`
- :func:`jax.tree.unflatten`
"""
"""Alias of :func:`jax.tree.structure`."""
return default_registry.flatten(tree, is_leaf)[1]
@ -432,42 +337,7 @@ def tree_map(f: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[[Any], bool] | None = None) -> Any:
"""Maps a multi-input function over pytree args to produce a new pytree.
Args:
f: function that takes ``1 + len(rest)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to ``f``.
rest: a tuple of pytrees, each of which has the same structure as ``tree``
or has ``tree`` as a prefix.
is_leaf: an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
A new pytree with the same structure as ``tree`` but with the value at each
leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding
leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
``rest``.
Examples:
>>> import jax.tree_util
>>> jax.tree_util.tree_map(lambda x: x + 1, {"x": 7, "y": 42})
{'x': 8, 'y': 43}
If multiple inputs are passed, the structure of the tree is taken from the
first input; subsequent inputs need only have ``tree`` as a prefix:
>>> jax.tree_util.tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
[[5, 7, 9], [6, 1, 2]]
See Also:
- :func:`jax.tree.leaves`
- :func:`jax.tree.reduce`
"""
"""Alias of :func:`jax.tree.map`."""
leaves, treedef = tree_flatten(tree, is_leaf)
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
@ -507,31 +377,7 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
@export
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
pytree_to_transpose: Any) -> Any:
"""Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
Args:
outer_treedef: PyTreeDef representing the outer tree.
inner_treedef: PyTreeDef representing the inner tree.
If None, then it will be inferred from outer_treedef and the structure of
pytree_to_transpose.
pytree_to_transpose: the pytree to be transposed.
Returns:
transposed_pytree: the transposed pytree.
Examples:
>>> import jax
>>> tree = [(1, 2, 3), (4, 5, 6)]
>>> inner_structure = jax.tree.structure(('*', '*', '*'))
>>> outer_structure = jax.tree.structure(['*', '*'])
>>> jax.tree.transpose(outer_structure, inner_structure, tree)
([1, 4], [2, 5], [3, 6])
Inferring the inner structure:
>>> jax.tree.transpose(outer_structure, None, tree)
([1, 4], [2, 5], [3, 6])
"""
"""Alias of :func:`jax.tree.transpose`."""
flat, treedef = tree_flatten(pytree_to_transpose)
if inner_treedef is None:
inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0])
@ -592,30 +438,7 @@ def tree_reduce(function: Callable[[T, Any], T],
tree: Any,
initializer: Any = no_initializer,
is_leaf: Callable[[Any], bool] | None = None) -> T:
"""Call reduce() over the leaves of a tree.
Args:
function: the reduction function
tree: the pytree to reduce over
initializer: the optional initial value
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
result: the reduced value.
Examples:
>>> import jax
>>> import operator
>>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])
21
See Also:
- :func:`jax.tree.leaves`
- :func:`jax.tree.map`
"""
"""Alias of :func:`jax.tree.reduce`."""
if initializer is no_initializer:
return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf))
else:
@ -624,29 +447,7 @@ def tree_reduce(function: Callable[[T, Any], T],
@export
def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
"""Call all() over the leaves of a tree.
Args:
tree: the pytree to evaluate
is_leaf : an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether the
flattening should traverse the current object, or if it should be stopped
immediately, with the whole subtree being treated as a leaf.
Returns:
result: boolean True or False
Examples:
>>> import jax
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
True
>>> jax.tree.all([False, (True, False)])
False
See Also:
- :func:`jax.tree_util.tree_reduce`
- :func:`jax.tree_util.tree_leaves`
"""
"""Alias of :func:`jax.tree.all`."""
return all(tree_leaves(tree, is_leaf=is_leaf))