mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 18:06:07 +00:00
271 lines
9.4 KiB
ReStructuredText
271 lines
9.4 KiB
ReStructuredText
.. _pytrees:
|
|
|
|
Pytrees
|
|
=======
|
|
|
|
What is a pytree?
|
|
^^^^^^^^^^^^^^^^^
|
|
|
|
In JAX, we use the term *pytree* to refer to a tree-like structure built out of
|
|
container-like Python objects. Classes are considered container-like if they
|
|
are in the pytree registry, which by default includes lists, tuples, and dicts.
|
|
That is:
|
|
|
|
1. any object whose type is *not* in the pytree container registry is
|
|
considered a *leaf* pytree;
|
|
2. any object whose type is in the pytree container registry, and which
|
|
contains pytrees, is considered a pytree.
|
|
|
|
For each entry in the pytree container registry, a container-like type is
|
|
registered with a pair of functions which specify how to convert an instance of
|
|
the container type to a ``(children, metadata)`` pair and how to convert such a
|
|
pair back to an instance of the container type. Using these functions, JAX can
|
|
canonicalize any tree of registered container objects into tuples.
|
|
|
|
Example pytrees::
|
|
|
|
[1, "a", object()] # 3 leaves
|
|
|
|
(1, (2, 3), ()) # 3 leaves
|
|
|
|
[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
|
|
|
|
JAX can be extended to consider other container types as pytrees; see
|
|
`Extending pytrees`_ below.
|
|
|
|
Pytrees and JAX functions
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
Many JAX functions, like ``jax.lax.scan``, operate over pytrees of arrays.
|
|
JAX function transformations can be applied to functions that accept as input
|
|
and produce as output pytrees of arrays.
|
|
|
|
Applying optional parameters to pytrees
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
Some JAX function transformations take optional parameters that specify how
|
|
certain input or output values should be treated (e.g. the ``in_axes`` and
|
|
``out_axes`` arguments to ``vmap``). These parameters can also be pytrees, and
|
|
their structure must correspond to the pytree structure of the corresponding
|
|
arguments. In particular, to be able to "match up" leaves in these parameter
|
|
pytrees with values in the argument pytrees, the parameter pytrees are often
|
|
constrained to be tree prefixes of the argument pytrees.
|
|
|
|
For example, if we pass the following input to ``vmap`` (note that the input
|
|
arguments to a function considered a tuple)::
|
|
|
|
(a1, {"k1": a2, "k2": a3})
|
|
|
|
We can use the following ``in_axes`` pytree to specify that only the ``k2``
|
|
argument is mapped (``axis=0``) and the rest aren't mapped over
|
|
(``axis=None``)::
|
|
|
|
(None, {"k1": None, "k2": 0})
|
|
|
|
The optional parameter pytree structure must match that of the main input
|
|
pytree. However, the optional parameters can optionally be specified as a
|
|
"prefix" pytree, meaning that a single leaf value can be applied to an entire
|
|
sub-pytree. For example, if we have the same ``vmap`` input as above, but wish
|
|
to only map over the dictionary argument, we can use::
|
|
|
|
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
|
|
|
|
Or, if we want every argument to be mapped, we can simply write a single leaf
|
|
value that is applied over the entire argument tuple pytree::
|
|
|
|
0
|
|
|
|
This happens to be the default ``in_axes`` value for ``vmap``!
|
|
|
|
The same logic applies to other optional parameters that refer to specific input
|
|
or output values of a transformed function, e.g. ``vmap``'s ``out_axes``.
|
|
|
|
Viewing the pytree definition of an object
|
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
To view the pytree definition of an arbitrary ``object`` for debugging purposes, you can use::
|
|
|
|
from jax.tree_util import tree_structure
|
|
print(tree_structure(object))
|
|
|
|
Developer information
|
|
^^^^^^^^^^^^^^^^^^^^^^
|
|
|
|
*This is primarily JAX internal documentation, end-users are not supposed to need
|
|
to understand this to use JAX, except when registering new user-defined
|
|
container types with JAX. Some of these details may change.*
|
|
|
|
Internal pytree handling
|
|
------------------------
|
|
|
|
JAX flattens pytrees into lists of leaves at the ``api.py`` boundary (and also
|
|
in control flow primitives). This keeps downstream JAX internals simpler:
|
|
transformations like ``grad``, ``jit``, and ``vmap`` can handle user functions
|
|
that accept and return the myriad different Python containers, while all the
|
|
other parts of the system can operate on functions that only take (multiple)
|
|
array arguments and always return a flat list of arrays.
|
|
|
|
When JAX flattens a pytree it will produce a list of leaves and a ``treedef``
|
|
object that encodes the structure of the original value. The ``treedef`` can
|
|
then be used to construct a matching structured value after transforming the
|
|
leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we
|
|
handle them assuming referential transparency and that they can't contain
|
|
reference cycles.
|
|
|
|
Here is a simple example::
|
|
|
|
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
|
import jax.numpy as jnp
|
|
|
|
# The structured value to be transformed
|
|
value_structured = [1., (2., 3.)]
|
|
|
|
# The leaves in value_flat correspond to the `*` markers in value_tree
|
|
value_flat, value_tree = tree_flatten(value_structured)
|
|
print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree))
|
|
|
|
# Transform the flat value list using an element-wise numeric transformer
|
|
transformed_flat = list(map(lambda v: v * 2., value_flat))
|
|
print("transformed_flat={}".format(transformed_flat))
|
|
|
|
# Reconstruct the structured output, using the original
|
|
transformed_structured = tree_unflatten(value_tree, transformed_flat)
|
|
print("transformed_structured={}".format(transformed_structured))
|
|
|
|
# Output:
|
|
# value_flat=[1.0, 2.0, 3.0]
|
|
# value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])
|
|
# transformed_flat=[2.0, 4.0, 6.0]
|
|
# transformed_structured=[2.0, (4.0, 6.0)]
|
|
|
|
By default, pytree containers can be lists, tuples, dicts, namedtuple, None,
|
|
OrderedDict. Other types of values, including numeric and ndarray values, are
|
|
treated as leaves::
|
|
|
|
from collections import namedtuple
|
|
Point = namedtuple('Point', ['x', 'y'])
|
|
|
|
example_containers = [
|
|
(1., [2., 3.]),
|
|
(1., {'b': 2., 'a': 3.}),
|
|
1.,
|
|
None,
|
|
jnp.zeros(2),
|
|
Point(1., 2.)
|
|
]
|
|
def show_example(structured):
|
|
flat, tree = tree_flatten(structured)
|
|
unflattened = tree_unflatten(tree, flat)
|
|
print("structured={}\n flat={}\n tree={}\n unflattened={}".format(
|
|
structured, flat, tree, unflattened))
|
|
|
|
for structured in example_containers:
|
|
show_example(structured)
|
|
|
|
# Output:
|
|
# structured=(1.0, [2.0, 3.0])
|
|
# flat=[1.0, 2.0, 3.0]
|
|
# tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])])
|
|
# unflattened=(1.0, [2.0, 3.0])
|
|
# structured=(1.0, {'b': 2.0, 'a': 3.0})
|
|
# flat=[1.0, 3.0, 2.0]
|
|
# tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])])
|
|
# unflattened=(1.0, {'a': 3.0, 'b': 2.0})
|
|
# structured=1.0
|
|
# flat=[1.0]
|
|
# tree=*
|
|
# unflattened=1.0
|
|
# structured=None
|
|
# flat=[]
|
|
# tree=PyTreeDef(None, [])
|
|
# unflattened=None
|
|
# structured=[0. 0.]
|
|
# flat=[DeviceArray([0., 0.], dtype=float32)]
|
|
# tree=*
|
|
# unflattened=[0. 0.]
|
|
# structured=Point(x=1.0, y=2.0)
|
|
# flat=[1.0, 2.0]
|
|
# tree=PyTreeDef(namedtuple[<class '__main__.Point'>], [*,*])
|
|
# unflattened=Point(x=1.0, y=2.0)
|
|
|
|
Extending pytrees
|
|
-----------------
|
|
|
|
By default, any part of a structured value that is not recognized as an
|
|
internal pytree node (i.e. container-like) is treated as a leaf::
|
|
|
|
class Special(object):
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
def __repr__(self):
|
|
return "Special(x={}, y={})".format(self.x, self.y)
|
|
|
|
|
|
show_example(Special(1., 2.))
|
|
|
|
# Output:
|
|
# structured=Special(x=1.0, y=2.0)
|
|
# flat=[Special(x=1.0, y=2.0)]
|
|
# tree=*
|
|
# unflattened=Special(x=1.0, y=2.0)
|
|
|
|
The set of Python types that are considered internal pytree nodes is extensible,
|
|
through a global registry of types. Values of registered types are traversed
|
|
recursively::
|
|
|
|
class RegisteredSpecial(Special):
|
|
def __repr__(self):
|
|
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
|
|
|
|
def special_flatten(v):
|
|
"""Specifies a flattening recipe.
|
|
|
|
Params:
|
|
v: the value of registered type to flatten.
|
|
Returns:
|
|
a pair of an iterable with the children to be flattened recursively,
|
|
and some opaque auxiliary data to pass back to the unflattening recipe.
|
|
The auxiliary data is stored in the treedef for use during unflattening.
|
|
The auxiliary data could be used, e.g., for dictionary keys.
|
|
"""
|
|
children = (v.x, v.y)
|
|
aux_data = None
|
|
return (children, aux_data)
|
|
|
|
def special_unflatten(aux_data, children):
|
|
"""Specifies an unflattening recipe.
|
|
|
|
Params:
|
|
aux_data: the opaque data that was specified during flattening of the
|
|
current treedef.
|
|
children: the unflattened children
|
|
|
|
Returns:
|
|
a re-constructed object of the registered type, using the specified
|
|
children and auxiliary data.
|
|
"""
|
|
return RegisteredSpecial(*children)
|
|
|
|
# Global registration
|
|
register_pytree_node(
|
|
RegisteredSpecial,
|
|
special_flatten, # tell JAX what are the children nodes
|
|
special_unflatten # tell JAX how to pack back into a RegisteredSpecial
|
|
)
|
|
|
|
show_example(RegisteredSpecial(1., 2.))
|
|
|
|
# Output:
|
|
# structured=RegisteredSpecial(x=1.0, y=2.0)
|
|
# flat=[1.0, 2.0]
|
|
# tree=PyTreeDef(<class '__main__.RegisteredSpecial'>[None], [*,*])
|
|
# unflattened=RegisteredSpecial(x=1.0, y=2.0)
|
|
|
|
JAX needs sometimes to compare ``treedef`` for equality. Therefore, care must be
|
|
taken to ensure that the auxiliary data specified in the flattening recipe
|
|
supports a meaningful equality comparison.
|
|
|
|
The whole set of functions for operating on pytrees are in `tree_util module
|
|
<https://jax.readthedocs.io/en/latest/jax.tree_util.html>`_.
|