mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Better documentation for jax.tree_util
This commit is contained in:
parent
3f1b059503
commit
1327143d46
@ -19,6 +19,7 @@ List of Functions
|
||||
register_pytree_node_class
|
||||
register_pytree_with_keys
|
||||
register_pytree_with_keys_class
|
||||
register_static
|
||||
tree_all
|
||||
tree_flatten
|
||||
tree_flatten_with_path
|
||||
|
@ -25,10 +25,12 @@ from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import safe_zip, set_module
|
||||
from jax._src.util import unzip2
|
||||
|
||||
|
||||
export = set_module('jax.tree_util')
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -67,6 +69,8 @@ dispatch_registry = pytree.PyTreeRegistry(
|
||||
dispatch_registry.__module__ = __name__
|
||||
dispatch_registry.__name__ = "dispatch_registry"
|
||||
|
||||
|
||||
@export
|
||||
def tree_flatten(tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> tuple[list[Leaf], PyTreeDef]:
|
||||
@ -88,12 +92,12 @@ def tree_flatten(tree: Any,
|
||||
element is a treedef representing the structure of the flattened tree.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
|
||||
>>> vals
|
||||
[1, 2, 3, 4, 5]
|
||||
>>> treedef
|
||||
PyTreeDef([*, (*, *), [*, *]])
|
||||
>>> import jax
|
||||
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
|
||||
>>> vals
|
||||
[1, 2, 3, 4, 5]
|
||||
>>> treedef
|
||||
PyTreeDef([*, (*, *), [*, *]])
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.leaves`
|
||||
@ -103,6 +107,7 @@ def tree_flatten(tree: Any,
|
||||
return default_registry.flatten(tree, is_leaf)
|
||||
|
||||
|
||||
@export
|
||||
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
|
||||
"""Reconstructs a pytree from the treedef and the leaves.
|
||||
|
||||
@ -118,11 +123,11 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
|
||||
described by ``treedef``.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
|
||||
>>> newvals = [100, 200, 300, 400, 500]
|
||||
>>> jax.tree.unflatten(treedef, newvals)
|
||||
[100, (200, 300), [400, 500]]
|
||||
>>> import jax
|
||||
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
|
||||
>>> newvals = [100, 200, 300, 400, 500]
|
||||
>>> jax.tree.unflatten(treedef, newvals)
|
||||
[100, (200, 300), [400, 500]]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
@ -132,6 +137,7 @@ def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
|
||||
return treedef.unflatten(leaves)
|
||||
|
||||
|
||||
@export
|
||||
def tree_leaves(tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> list[Leaf]:
|
||||
@ -143,13 +149,14 @@ def tree_leaves(tree: Any,
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
leaves: a list of tree leaves.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> jax.tree.leaves([1, (2, 3), [4, 5]])
|
||||
[1, 2, 3, 4, 5]
|
||||
>>> import jax
|
||||
>>> jax.tree.leaves([1, (2, 3), [4, 5]])
|
||||
[1, 2, 3, 4, 5]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
@ -159,6 +166,7 @@ def tree_leaves(tree: Any,
|
||||
return default_registry.flatten(tree, is_leaf)[0]
|
||||
|
||||
|
||||
@export
|
||||
def tree_structure(tree: Any,
|
||||
is_leaf: None | (Callable[[Any],
|
||||
bool]) = None) -> PyTreeDef:
|
||||
@ -170,12 +178,14 @@ def tree_structure(tree: Any,
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
pytreedef: a PyTreeDef representing the structure of the tree.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> jax.tree.structure([1, (2, 3), [4, 5]])
|
||||
PyTreeDef([*, (*, *), [*, *]])
|
||||
>>> import jax
|
||||
>>> jax.tree.structure([1, (2, 3), [4, 5]])
|
||||
PyTreeDef([*, (*, *), [*, *]])
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
@ -185,31 +195,89 @@ def tree_structure(tree: Any,
|
||||
return default_registry.flatten(tree, is_leaf)[1]
|
||||
|
||||
|
||||
@export
|
||||
def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
|
||||
"""Makes a tuple treedef from an iterable of child treedefs."""
|
||||
"""Makes a tuple treedef from an iterable of child treedefs.
|
||||
|
||||
Args:
|
||||
treedefs: iterable of PyTree structures
|
||||
|
||||
Returns:
|
||||
a single treedef representing a tuple of the structures
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> x = [1, 2, 3]
|
||||
>>> y = {'a': 4, 'b': 5}
|
||||
>>> x_tree = jax.tree.structure(x)
|
||||
>>> y_tree = jax.tree.structure(y)
|
||||
>>> xy_tree = jax.tree_util.treedef_tuple([x_tree, y_tree])
|
||||
>>> xy_tree == jax.tree.structure((x, y))
|
||||
True
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree_util.treedef_children`
|
||||
"""
|
||||
return pytree.tuple(default_registry, list(treedefs)) # type: ignore
|
||||
|
||||
|
||||
@export
|
||||
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
|
||||
"""Return a list of treedefs for immediate children
|
||||
|
||||
Args:
|
||||
treedef: a single PyTreeDef
|
||||
|
||||
Returns:
|
||||
a list of PyTreeDefs representing the children of treedef.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> x = [(1, 2), 3, {'a': 4}]
|
||||
>>> treedef = jax.tree.structure(x)
|
||||
>>> jax.tree_util.treedef_children(treedef)
|
||||
[PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
|
||||
>>> _ == [jax.tree.structure(vals) for vals in x]
|
||||
True
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree_util.treedef_tuple`
|
||||
"""
|
||||
return treedef.children()
|
||||
|
||||
|
||||
@export
|
||||
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
|
||||
"""Return True if the treedef represents a leaf.
|
||||
|
||||
Args:
|
||||
treedef: tree to check
|
||||
|
||||
Returns:
|
||||
True if treedef is a leaf (i.e. has a single node); False otherwise.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> tree1 = jax.tree.structure(1)
|
||||
>>> jax.tree_util.treedef_is_leaf(tree1)
|
||||
True
|
||||
>>> tree2 = jax.tree.structure([1, 2])
|
||||
>>> jax.tree_util.treedef_is_leaf(tree2)
|
||||
False
|
||||
"""
|
||||
return treedef.num_nodes == 1
|
||||
|
||||
|
||||
# treedef_is_strict_leaf is not exported.
|
||||
def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool:
|
||||
return treedef.num_nodes == 1 and treedef.num_leaves == 1
|
||||
|
||||
|
||||
@export
|
||||
def all_leaves(iterable: Iterable[Any],
|
||||
is_leaf: Callable[[Any], bool] | None = None) -> bool:
|
||||
"""Tests whether all elements in the given iterable are all leaves.
|
||||
|
||||
>>> tree = {"a": [1, 2, 3]}
|
||||
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
|
||||
>>> assert not all_leaves([tree])
|
||||
|
||||
This function is useful in advanced cases, for example if a library allows
|
||||
arbitrary map operations on a flat iterable of leaves it may want to check
|
||||
if the result is still a flat iterable of leaves.
|
||||
@ -219,6 +287,12 @@ def all_leaves(iterable: Iterable[Any],
|
||||
|
||||
Returns:
|
||||
A boolean indicating if all elements in the input are leaves.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> tree = {"a": [1, 2, 3]}
|
||||
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
|
||||
>>> assert not all_leaves([tree])
|
||||
"""
|
||||
if is_leaf is None:
|
||||
return pytree.all_leaves(default_registry, iterable)
|
||||
@ -230,15 +304,17 @@ def all_leaves(iterable: Iterable[Any],
|
||||
_Children = TypeVar("_Children", bound=Iterable[Any])
|
||||
_AuxData = TypeVar("_AuxData", bound=Hashable)
|
||||
|
||||
|
||||
@export
|
||||
def register_pytree_node(nodetype: type[T],
|
||||
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
|
||||
unflatten_func: Callable[[_AuxData, _Children], T]):
|
||||
unflatten_func: Callable[[_AuxData, _Children], T]) -> None:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
See :ref:`example usage <pytrees>`.
|
||||
|
||||
Args:
|
||||
nodetype: a Python type to treat as an internal pytree node.
|
||||
nodetype: a Python type to register as a pytree.
|
||||
flatten_func: a function to be used during flattening, taking a value of
|
||||
type ``nodetype`` and returning a pair, with (1) an iterable for the
|
||||
children to be flattened recursively, and (2) some hashable auxiliary data
|
||||
@ -247,6 +323,55 @@ def register_pytree_node(nodetype: type[T],
|
||||
returned by ``flatten_func`` and stored in the treedef, and the
|
||||
unflattened children. The function should return an instance of
|
||||
``nodetype``.
|
||||
|
||||
See also:
|
||||
- :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
|
||||
- :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
|
||||
- :func:`~jax.tree_util.register_pytree_with_keys`
|
||||
- :func:`~jax.tree_util.register_pytree_node_class`
|
||||
- :func:`~jax.tree_util.register_pytree_with_keys_class`
|
||||
|
||||
Example:
|
||||
First we'll define a custom type:
|
||||
|
||||
>>> class MyContainer:
|
||||
... def __init__(self, size):
|
||||
... self.x = jnp.zeros(size)
|
||||
... self.y = jnp.ones(size)
|
||||
... self.size = size
|
||||
|
||||
If we try using this in a JIT-compiled function, we'll get an error because JAX
|
||||
does not yet know how to handle this type:
|
||||
|
||||
>>> m = MyContainer(size=5)
|
||||
>>> def f(m):
|
||||
... return m.x + m.y + jnp.arange(m.size)
|
||||
>>> jax.jit(f)(m) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute
|
||||
|
||||
In order to make our object recognized by JAX, we must register it as
|
||||
a pytree:
|
||||
|
||||
>>> def flatten_func(obj):
|
||||
... children = (obj.x, obj.y) # children must contain arrays & pytrees
|
||||
... aux_data = (obj.size,) # aux_data must contain static, hashable data.
|
||||
... return (children, aux_data)
|
||||
...
|
||||
>>> def unflatten_func(aux_data, children):
|
||||
... # Here we avoid `__init__` because it has extra logic we don't require:
|
||||
... obj = object.__new__(MyContainer)
|
||||
... obj.x, obj.y = children
|
||||
... obj.size, = aux_data
|
||||
... return obj
|
||||
...
|
||||
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)
|
||||
|
||||
Now with this defined, we can use instances of this type in JIT-compiled functions.
|
||||
|
||||
>>> jax.jit(f)(m)
|
||||
Array([1., 2., 3., 4., 5.], dtype=float32)
|
||||
"""
|
||||
default_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
@ -254,27 +379,55 @@ def register_pytree_node(nodetype: type[T],
|
||||
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
||||
|
||||
|
||||
@export
|
||||
def register_pytree_node_class(cls: Typ) -> Typ:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
This function is a thin wrapper around ``register_pytree_node``, and provides
|
||||
a class-oriented interface::
|
||||
a class-oriented interface.
|
||||
|
||||
@register_pytree_node_class
|
||||
class Special:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
def tree_flatten(self):
|
||||
return ((self.x, self.y), None)
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
Args:
|
||||
cls: a type to register as a pytree
|
||||
|
||||
Returns:
|
||||
The input class ``cls`` is returned unchanged after being added to JAX's pytree
|
||||
registry. This return value allows ``register_pytree_node_class`` to be used as
|
||||
a decorator.
|
||||
|
||||
See also:
|
||||
- :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
|
||||
- :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
|
||||
- :func:`~jax.tree_util.register_pytree_node`
|
||||
- :func:`~jax.tree_util.register_pytree_with_keys`
|
||||
- :func:`~jax.tree_util.register_pytree_with_keys_class`
|
||||
|
||||
Example:
|
||||
Here we'll define a custom container that will be compatible with :func:`jax.jit`
|
||||
and other JAX transformations:
|
||||
|
||||
>>> import jax
|
||||
>>> @jax.tree_util.register_pytree_node_class
|
||||
... class MyContainer:
|
||||
... def __init__(self, x, y):
|
||||
... self.x = x
|
||||
... self.y = y
|
||||
... def tree_flatten(self):
|
||||
... return ((self.x, self.y), None)
|
||||
... @classmethod
|
||||
... def tree_unflatten(cls, aux_data, children):
|
||||
... return cls(*children)
|
||||
...
|
||||
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
|
||||
>>> def f(m):
|
||||
... return m.x + 2 * m.y
|
||||
>>> jax.jit(f)(m)
|
||||
Array([0., 2., 4., 6.], dtype=float32)
|
||||
"""
|
||||
register_pytree_node(cls, op.methodcaller("tree_flatten"), cls.tree_unflatten)
|
||||
return cls
|
||||
|
||||
|
||||
@export
|
||||
def tree_map(f: Callable[..., Any],
|
||||
tree: Any,
|
||||
*rest: Any,
|
||||
@ -320,10 +473,38 @@ def tree_map(f: Callable[..., Any],
|
||||
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
|
||||
|
||||
|
||||
@export
|
||||
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
|
||||
"""Build a treedef from a nested iterable structure
|
||||
|
||||
Args:
|
||||
treedef: the PyTreeDef structure to build.
|
||||
xs: nested iterables matching the arity as the treedef
|
||||
|
||||
Returns:
|
||||
object with structure defined by treedef
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.unflatten`
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
|
||||
>>> treedef = jax.tree.structure(tree)
|
||||
|
||||
Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct
|
||||
the tree from new values, but ``build_tree`` takes these values in terms of
|
||||
a nested rather than flat structure:
|
||||
|
||||
>>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]])
|
||||
[(10, 11), {'a': 12, 'b': 13}]
|
||||
>>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
|
||||
[(10, 11), {'a': 12, 'b': 13}]
|
||||
"""
|
||||
return treedef.from_iterable_tree(xs)
|
||||
|
||||
|
||||
@export
|
||||
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
|
||||
pytree_to_transpose: Any) -> Any:
|
||||
"""Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
|
||||
@ -406,6 +587,7 @@ def tree_reduce(function: Callable[[T, Any], T],
|
||||
...
|
||||
|
||||
|
||||
@export
|
||||
def tree_reduce(function: Callable[[T, Any], T],
|
||||
tree: Any,
|
||||
initializer: Any = no_initializer,
|
||||
@ -439,6 +621,8 @@ def tree_reduce(function: Callable[[T, Any], T],
|
||||
else:
|
||||
return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf), initializer)
|
||||
|
||||
|
||||
@export
|
||||
def tree_all(tree: Any) -> bool:
|
||||
"""Call all() over the leaves of a tree.
|
||||
|
||||
@ -449,12 +633,15 @@ def tree_all(tree: Any) -> bool:
|
||||
result: boolean True or False
|
||||
|
||||
Examples:
|
||||
|
||||
>>> import jax
|
||||
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
|
||||
True
|
||||
>>> jax.tree.all([False, (True, False)])
|
||||
False
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree_util.tree_reduce`
|
||||
- :func:`jax.tree_util.tree_leaves`
|
||||
"""
|
||||
return all(tree_leaves(tree))
|
||||
|
||||
@ -495,6 +682,7 @@ class _HashableCallableShim:
|
||||
return f'_HashableCallableShim({self.fun!r})'
|
||||
|
||||
|
||||
@export
|
||||
class Partial(functools.partial):
|
||||
"""A version of functools.partial that works in pytrees.
|
||||
|
||||
@ -531,8 +719,7 @@ class Partial(functools.partial):
|
||||
Array(3, dtype=int32, weak_type=True)
|
||||
|
||||
Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in
|
||||
a
|
||||
``TypeError``.
|
||||
a ``TypeError``.
|
||||
|
||||
Note that if the result of ``Partial`` is used in the context where the
|
||||
value is traced, it results in all bound arguments being traced when passed
|
||||
@ -570,6 +757,7 @@ register_pytree_node(
|
||||
)
|
||||
|
||||
|
||||
# broadcast_prefix is not exported.
|
||||
def broadcast_prefix(prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> list[Any]:
|
||||
@ -583,19 +771,29 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
|
||||
return result
|
||||
|
||||
|
||||
# flatten_one_level is not exported.
|
||||
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
"""Flatten the given pytree node by one level.
|
||||
|
||||
Args:
|
||||
pytree: A valid pytree node, either built-in or registered via
|
||||
``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
:func:`register_pytree_node` or related functions.
|
||||
|
||||
Returns:
|
||||
A pair of the pytree's flattened children and its hashable metadata.
|
||||
A pair of the pytrees flattened children and its hashable metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given pytree is not a built-in or registered container
|
||||
via ``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> from jax._src.tree_util import flatten_one_level
|
||||
>>> flattened, meta = flatten_one_level({'a': [1, 2], 'b': {'c': 3}})
|
||||
>>> flattened
|
||||
([1, 2], {'c': 3})
|
||||
>>> meta
|
||||
('a', 'b')
|
||||
"""
|
||||
out = default_registry.flatten_one_level(pytree)
|
||||
if out is None:
|
||||
@ -603,11 +801,15 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
# prefix_errors is not exported
|
||||
def prefix_errors(prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
) -> list[Callable[[str], ValueError]]:
|
||||
return list(_prefix_error((), prefix_tree, full_tree, is_leaf))
|
||||
|
||||
|
||||
# equality_errors is not exported
|
||||
def equality_errors(
|
||||
tree1: Any, tree2: Any, is_leaf: Callable[[Any], bool] | None = None,
|
||||
) -> Iterable[tuple[KeyPath, str, str, str]]:
|
||||
@ -687,26 +889,37 @@ def _equality_errors(path, t1, t2, is_leaf):
|
||||
yield from _equality_errors((*path, k), c1, c2, is_leaf)
|
||||
|
||||
|
||||
@export
|
||||
@dataclass(frozen=True)
|
||||
class SequenceKey():
|
||||
"""Struct for use with :func:`jax.tree_util.register_pytree_with_keys`."""
|
||||
idx: int
|
||||
def __str__(self):
|
||||
return f'[{self.idx!r}]'
|
||||
|
||||
|
||||
@export
|
||||
@dataclass(frozen=True)
|
||||
class DictKey():
|
||||
"""Struct for use with :func:`jax.tree_util.register_pytree_with_keys`."""
|
||||
key: Hashable
|
||||
def __str__(self):
|
||||
return f'[{self.key!r}]'
|
||||
|
||||
|
||||
@export
|
||||
@dataclass(frozen=True)
|
||||
class GetAttrKey():
|
||||
"""Struct for use with :func:`jax.tree_util.register_pytree_with_keys`."""
|
||||
name: str
|
||||
def __str__(self):
|
||||
return f'.{self.name}'
|
||||
|
||||
|
||||
@export
|
||||
@dataclass(frozen=True)
|
||||
class FlattenedIndexKey():
|
||||
"""Struct for use with :func:`jax.tree_util.register_pytree_with_keys`."""
|
||||
key: int
|
||||
def __str__(self):
|
||||
return f'[<flat index {self.key}>]'
|
||||
@ -716,6 +929,8 @@ BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
|
||||
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
|
||||
KeyPath = tuple[KeyEntry, ...]
|
||||
|
||||
|
||||
@export
|
||||
def keystr(keys: KeyPath):
|
||||
"""Helper to pretty-print a tuple of keys.
|
||||
|
||||
@ -724,6 +939,12 @@ def keystr(keys: KeyPath):
|
||||
|
||||
Returns:
|
||||
A string that joins all string representations of the keys.
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> keys = (0, 1, 'a', 'b')
|
||||
>>> jax.tree_util.keystr(keys)
|
||||
'01ab'
|
||||
"""
|
||||
return ''.join([str(k) for k in keys])
|
||||
|
||||
@ -764,6 +985,7 @@ _register_keypaths(
|
||||
)
|
||||
|
||||
|
||||
@export
|
||||
def register_pytree_with_keys(
|
||||
nodetype: type[T],
|
||||
flatten_with_keys: Callable[
|
||||
@ -794,6 +1016,38 @@ def register_pytree_with_keys(
|
||||
in the same order as ``flatten_with_keys``, and return the same aux data.
|
||||
This argument is optional and only needed for faster traversal when
|
||||
calling functions without keys like ``tree_map`` and ``tree_flatten``.
|
||||
|
||||
Example:
|
||||
First we'll define a custom type:
|
||||
|
||||
>>> class MyContainer:
|
||||
... def __init__(self, size):
|
||||
... self.x = jnp.zeros(size)
|
||||
... self.y = jnp.ones(size)
|
||||
... self.size = size
|
||||
|
||||
Now register it using a key-aware flatten function:
|
||||
|
||||
>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
|
||||
>>> def flatten_with_keys(obj):
|
||||
... children = [(GetAttrKey('x'), obj.x),
|
||||
... (GetAttrKey('y'), obj.y)] # children must contain arrays & pytrees
|
||||
... aux_data = (obj.size,) # aux_data must contain static, hashable data.
|
||||
... return children, aux_data
|
||||
...
|
||||
>>> def unflatten(aux_data, children):
|
||||
... # Here we avoid `__init__` because it has extra logic we don't require:
|
||||
... obj = object.__new__(MyContainer)
|
||||
... obj.x, obj.y = children
|
||||
... obj.size, = aux_data
|
||||
... return obj
|
||||
...
|
||||
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_with_keys, unflatten)
|
||||
|
||||
Now this can be used with functions like :func:`~jax.tree_util.tree_flatten_with_path`:
|
||||
|
||||
>>> m = MyContainer(4)
|
||||
>>> leaves, treedef = jax.tree_util.tree_flatten_with_path(m)
|
||||
"""
|
||||
if not flatten_func:
|
||||
def flatten_func_impl(tree):
|
||||
@ -807,6 +1061,7 @@ def register_pytree_with_keys(
|
||||
)
|
||||
|
||||
|
||||
@export
|
||||
def register_pytree_with_keys_class(cls: Typ) -> Typ:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
@ -814,18 +1069,35 @@ def register_pytree_with_keys_class(cls: Typ) -> Typ:
|
||||
class that defines how it could be flattened with keys.
|
||||
|
||||
It is a thin wrapper around ``register_pytree_with_keys``, and
|
||||
provides a class-oriented interface::
|
||||
provides a class-oriented interface:
|
||||
|
||||
@register_pytree_with_keys_class
|
||||
class Special:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
def tree_flatten_with_keys(self):
|
||||
return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
Args:
|
||||
cls: a type to register as a pytree
|
||||
|
||||
Returns:
|
||||
The input class ``cls`` is returned unchanged after being added to JAX's pytree
|
||||
registry. This return value allows ``register_pytree_node_class`` to be used as
|
||||
a decorator.
|
||||
|
||||
See also:
|
||||
- :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
|
||||
- :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
|
||||
- :func:`~jax.tree_util.register_pytree_node`
|
||||
- :func:`~jax.tree_util.register_pytree_with_keys`
|
||||
- :func:`~jax.tree_util.register_pytree_node_class`
|
||||
|
||||
Example:
|
||||
>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
|
||||
>>> @register_pytree_with_keys_class
|
||||
... class Special:
|
||||
... def __init__(self, x, y):
|
||||
... self.x = x
|
||||
... self.y = y
|
||||
... def tree_flatten_with_keys(self):
|
||||
... return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
|
||||
... @classmethod
|
||||
... def tree_unflatten(cls, aux_data, children):
|
||||
... return cls(*children)
|
||||
"""
|
||||
flatten_func = (
|
||||
op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None
|
||||
@ -837,6 +1109,7 @@ def register_pytree_with_keys_class(cls: Typ) -> Typ:
|
||||
return cls
|
||||
|
||||
|
||||
@export
|
||||
def register_dataclass(
|
||||
nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str]
|
||||
) -> Typ:
|
||||
@ -940,23 +1213,37 @@ def register_dataclass(
|
||||
return nodetype
|
||||
|
||||
|
||||
@export
|
||||
def register_static(cls: type[H]) -> type[H]:
|
||||
"""Registers `cls` as a pytree with no leaves.
|
||||
|
||||
Instances are treated as static by `jax.jit`, `jax.pmap`, etc. This can be an
|
||||
alternative to labeling inputs as static using `jax.jit`'s `static_argnums`
|
||||
and `static_argnames` kwargs, `jax.pmap`'s `static_broadcasted_argnums`, etc.
|
||||
Instances are treated as static by :func:`jax.jit`, :func:`jax.pmap`, etc. This can
|
||||
be an alternative to labeling inputs as static using ``jit``'s ``static_argnums``
|
||||
and ``static_argnames`` kwargs, ``pmap``'s ``static_broadcasted_argnums``, etc.
|
||||
|
||||
`cls` must be hashable, as defined in
|
||||
https://docs.python.org/3/glossary.html#term-hashable.
|
||||
Args:
|
||||
cls: type to be registered as static. Must be hashable, as defined in
|
||||
https://docs.python.org/3/glossary.html#term-hashable.
|
||||
|
||||
`register_static` can be applied to subclasses of builtin hashable classes
|
||||
such as `str`, like this:
|
||||
```
|
||||
@tree_util.register_static
|
||||
class StaticStr(str):
|
||||
pass
|
||||
```
|
||||
Returns:
|
||||
The input class ``cls`` is returned unchanged after being added to JAX's
|
||||
pytree registry. This allows ``register_static`` to be used as a decorator.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> @jax.tree_util.register_static
|
||||
... class StaticStr(str):
|
||||
... pass
|
||||
|
||||
This static string can now be used directly in :func:`jax.jit`-compiled
|
||||
functions, without marking the variable static using ``static_argnums``:
|
||||
|
||||
>>> @jax.jit
|
||||
... def f(x, y, s):
|
||||
... return x + y if s == 'add' else x - y
|
||||
...
|
||||
>>> f(1, 2, StaticStr('add'))
|
||||
Array(3, dtype=int32, weak_type=True)
|
||||
"""
|
||||
flatten = lambda obj: ((), obj)
|
||||
unflatten = lambda obj, empty_iter_children: obj
|
||||
@ -964,6 +1251,7 @@ def register_static(cls: type[H]) -> type[H]:
|
||||
return cls
|
||||
|
||||
|
||||
@export
|
||||
def tree_flatten_with_path(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
|
||||
@ -981,6 +1269,7 @@ def tree_flatten_with_path(
|
||||
return _generate_key_paths(tree, is_leaf), tree_def
|
||||
|
||||
|
||||
@export
|
||||
def tree_leaves_with_path(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> list[tuple[KeyPath, Any]]:
|
||||
@ -991,10 +1280,15 @@ def tree_leaves_with_path(
|
||||
``register_pytree_with_keys``.
|
||||
Returns:
|
||||
A list of key-leaf pairs, each of which contains a leaf and its key path.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree_util.tree_leaves`
|
||||
- :func:`jax.tree_util.tree_flatten_with_path`
|
||||
"""
|
||||
return _generate_key_paths(tree, is_leaf)
|
||||
|
||||
|
||||
# generate_key_paths is not exported.
|
||||
def generate_key_paths(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> list[tuple[KeyPath, Any]]:
|
||||
@ -1036,6 +1330,7 @@ def _generate_key_paths_(
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
|
||||
|
||||
@export
|
||||
def tree_map_with_path(f: Callable[..., Any],
|
||||
tree: Any, *rest: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None) -> Any:
|
||||
@ -1057,6 +1352,11 @@ def tree_map_with_path(f: Callable[..., Any],
|
||||
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
|
||||
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
|
||||
the tuple of values at corresponding nodes in ``rest``.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree_util.tree_map`
|
||||
- :func:`jax.tree_util.tree_flatten_with_path`
|
||||
- :func:`jax.tree_util.tree_leaves_with_path`
|
||||
"""
|
||||
|
||||
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
|
||||
|
@ -32,6 +32,7 @@ class PackageStructureTest(jtu.JaxTestCase):
|
||||
# TODO(jakevdp): expand test to other public modules.
|
||||
_mod("jax.errors"),
|
||||
_mod("jax.nn.initializers"),
|
||||
_mod("jax.tree_util", exclude=['PyTreeDef', 'default_registry']),
|
||||
])
|
||||
def test_exported_names_match_module(self, module_name, include, exclude):
|
||||
"""Test that all public exports have __module__ set correctly."""
|
||||
@ -43,7 +44,8 @@ class PackageStructureTest(jtu.JaxTestCase):
|
||||
obj = getattr(module, name)
|
||||
if isinstance(obj, types.ModuleType):
|
||||
continue
|
||||
self.assertEqual(obj.__module__, module_name)
|
||||
self.assertEqual(obj.__module__, module_name,
|
||||
f"{obj} has {obj.__module__=}, expected {module_name}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user