mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix broken example in pytrees.rst & migrate to myst-md
This commit is contained in:
parent
7e040c0a89
commit
35983ebba6
260
docs/pytrees.md
Normal file
260
docs/pytrees.md
Normal file
@ -0,0 +1,260 @@
|
||||
---
|
||||
jupytext:
|
||||
text_representation:
|
||||
extension: .md
|
||||
format_name: myst
|
||||
kernelspec:
|
||||
display_name: Python 3
|
||||
language: python
|
||||
name: python3
|
||||
language_info:
|
||||
name: python
|
||||
file_extension: .py
|
||||
---
|
||||
|
||||
(pytrees)=
|
||||
|
||||
# Pytrees
|
||||
|
||||
## What is a pytree?
|
||||
|
||||
In JAX, we use the term *pytree* to refer to a tree-like structure built out of
|
||||
container-like Python objects. Classes are considered container-like if they
|
||||
are in the pytree registry, which by default includes lists, tuples, and dicts.
|
||||
That is:
|
||||
|
||||
1. any object whose type is *not* in the pytree container registry is
|
||||
considered a *leaf* pytree;
|
||||
2. any object whose type is in the pytree container registry, and which
|
||||
contains pytrees, is considered a pytree.
|
||||
|
||||
For each entry in the pytree container registry, a container-like type is
|
||||
registered with a pair of functions which specify how to convert an instance of
|
||||
the container type to a `(children, metadata)` pair and how to convert such a
|
||||
pair back to an instance of the container type. Using these functions, JAX can
|
||||
canonicalize any tree of registered container objects into tuples.
|
||||
|
||||
Example pytrees:
|
||||
|
||||
```
|
||||
[1, "a", object()] # 3 leaves
|
||||
|
||||
(1, (2, 3), ()) # 3 leaves
|
||||
|
||||
[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
|
||||
```
|
||||
|
||||
JAX can be extended to consider other container types as pytrees; see
|
||||
[Extending pytrees](extending pytrees) below.
|
||||
|
||||
## Pytrees and JAX functions
|
||||
|
||||
Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays.
|
||||
JAX function transformations can be applied to functions that accept as input
|
||||
and produce as output pytrees of arrays.
|
||||
|
||||
## Applying optional parameters to pytrees
|
||||
|
||||
Some JAX function transformations take optional parameters that specify how
|
||||
certain input or output values should be treated (e.g. the `in_axes` and
|
||||
`out_axes` arguments to {func}`~jax.vmap`). These parameters can also be pytrees,
|
||||
and their structure must correspond to the pytree structure of the corresponding
|
||||
arguments. In particular, to be able to "match up" leaves in these parameter
|
||||
pytrees with values in the argument pytrees, the parameter pytrees are often
|
||||
constrained to be tree prefixes of the argument pytrees.
|
||||
|
||||
For example, if we pass the following input to {func}`~jax.vmap` (note that the input
|
||||
arguments to a function considered a tuple):
|
||||
|
||||
```
|
||||
(a1, {"k1": a2, "k2": a3})
|
||||
```
|
||||
|
||||
We can use the following `in_axes` pytree to specify that only the `k2`
|
||||
argument is mapped (`axis=0`) and the rest aren't mapped over
|
||||
(`axis=None`):
|
||||
|
||||
```
|
||||
(None, {"k1": None, "k2": 0})
|
||||
```
|
||||
|
||||
The optional parameter pytree structure must match that of the main input
|
||||
pytree. However, the optional parameters can optionally be specified as a
|
||||
"prefix" pytree, meaning that a single leaf value can be applied to an entire
|
||||
sub-pytree. For example, if we have the same {func}`~jax.vmap` input as above,
|
||||
but wish to only map over the dictionary argument, we can use:
|
||||
|
||||
```
|
||||
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
|
||||
```
|
||||
|
||||
Or, if we want every argument to be mapped, we can simply write a single leaf
|
||||
value that is applied over the entire argument tuple pytree:
|
||||
|
||||
```
|
||||
0
|
||||
```
|
||||
|
||||
This happens to be the default `in_axes` value for {func}`~jax.vmap`!
|
||||
|
||||
The same logic applies to other optional parameters that refer to specific input
|
||||
or output values of a transformed function, e.g. `vmap`'s `out_axes`.
|
||||
|
||||
## Viewing the pytree definition of an object
|
||||
|
||||
To view the pytree definition of an arbitrary `object` for debugging purposes, you can use:
|
||||
|
||||
```
|
||||
from jax.tree_util import tree_structure
|
||||
print(tree_structure(object))
|
||||
```
|
||||
|
||||
## Developer information
|
||||
|
||||
*This is primarily JAX internal documentation, end-users are not supposed to need
|
||||
to understand this to use JAX, except when registering new user-defined
|
||||
container types with JAX. Some of these details may change.*
|
||||
|
||||
### Internal pytree handling
|
||||
|
||||
JAX flattens pytrees into lists of leaves at the `api.py` boundary (and also
|
||||
in control flow primitives). This keeps downstream JAX internals simpler:
|
||||
transformations like {func}`~jax.grad`, {func}`~jax.jit`, and {func}`~jax.vmap`
|
||||
can handle user functions that accept and return the myriad different Python
|
||||
containers, while all the other parts of the system can operate on functions
|
||||
that only take (multiple) array arguments and always return a flat list of arrays.
|
||||
|
||||
When JAX flattens a pytree it will produce a list of leaves and a `treedef`
|
||||
object that encodes the structure of the original value. The `treedef` can
|
||||
then be used to construct a matching structured value after transforming the
|
||||
leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we
|
||||
handle them assuming referential transparency and that they can't contain
|
||||
reference cycles.
|
||||
|
||||
Here is a simple example:
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
|
||||
# Execute this to consume & hide the GPU warning.
|
||||
import jax.numpy as _jnp
|
||||
_jnp.arange(10)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
import jax.numpy as jnp
|
||||
|
||||
# The structured value to be transformed
|
||||
value_structured = [1., (2., 3.)]
|
||||
|
||||
# The leaves in value_flat correspond to the `*` markers in value_tree
|
||||
value_flat, value_tree = tree_flatten(value_structured)
|
||||
print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree))
|
||||
|
||||
# Transform the flat value list using an element-wise numeric transformer
|
||||
transformed_flat = list(map(lambda v: v * 2., value_flat))
|
||||
print("transformed_flat={}".format(transformed_flat))
|
||||
|
||||
# Reconstruct the structured output, using the original
|
||||
transformed_structured = tree_unflatten(value_tree, transformed_flat)
|
||||
print("transformed_structured={}".format(transformed_structured))
|
||||
```
|
||||
|
||||
By default, pytree containers can be lists, tuples, dicts, namedtuple, None,
|
||||
OrderedDict. Other types of values, including numeric and ndarray values, are
|
||||
treated as leaves:
|
||||
|
||||
```{code-cell}
|
||||
from collections import namedtuple
|
||||
Point = namedtuple('Point', ['x', 'y'])
|
||||
|
||||
example_containers = [
|
||||
(1., [2., 3.]),
|
||||
(1., {'b': 2., 'a': 3.}),
|
||||
1.,
|
||||
None,
|
||||
jnp.zeros(2),
|
||||
Point(1., 2.)
|
||||
]
|
||||
def show_example(structured):
|
||||
flat, tree = tree_flatten(structured)
|
||||
unflattened = tree_unflatten(tree, flat)
|
||||
print("structured={}\n flat={}\n tree={}\n unflattened={}".format(
|
||||
structured, flat, tree, unflattened))
|
||||
|
||||
for structured in example_containers:
|
||||
show_example(structured)
|
||||
```
|
||||
|
||||
### Extending pytrees
|
||||
|
||||
By default, any part of a structured value that is not recognized as an
|
||||
internal pytree node (i.e. container-like) is treated as a leaf:
|
||||
|
||||
```{code-cell}
|
||||
class Special(object):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __repr__(self):
|
||||
return "Special(x={}, y={})".format(self.x, self.y)
|
||||
|
||||
|
||||
show_example(Special(1., 2.))
|
||||
```
|
||||
|
||||
The set of Python types that are considered internal pytree nodes is extensible,
|
||||
through a global registry of types. Values of registered types are traversed
|
||||
recursively:
|
||||
|
||||
```{code-cell}
|
||||
class RegisteredSpecial(Special):
|
||||
def __repr__(self):
|
||||
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
|
||||
|
||||
def special_flatten(v):
|
||||
"""Specifies a flattening recipe.
|
||||
|
||||
Params:
|
||||
v: the value of registered type to flatten.
|
||||
Returns:
|
||||
a pair of an iterable with the children to be flattened recursively,
|
||||
and some opaque auxiliary data to pass back to the unflattening recipe.
|
||||
The auxiliary data is stored in the treedef for use during unflattening.
|
||||
The auxiliary data could be used, e.g., for dictionary keys.
|
||||
"""
|
||||
children = (v.x, v.y)
|
||||
aux_data = None
|
||||
return (children, aux_data)
|
||||
|
||||
def special_unflatten(aux_data, children):
|
||||
"""Specifies an unflattening recipe.
|
||||
|
||||
Params:
|
||||
aux_data: the opaque data that was specified during flattening of the
|
||||
current treedef.
|
||||
children: the unflattened children
|
||||
|
||||
Returns:
|
||||
a re-constructed object of the registered type, using the specified
|
||||
children and auxiliary data.
|
||||
"""
|
||||
return RegisteredSpecial(*children)
|
||||
|
||||
# Global registration
|
||||
register_pytree_node(
|
||||
RegisteredSpecial,
|
||||
special_flatten, # tell JAX what are the children nodes
|
||||
special_unflatten # tell JAX how to pack back into a RegisteredSpecial
|
||||
)
|
||||
|
||||
show_example(RegisteredSpecial(1., 2.))
|
||||
```
|
||||
|
||||
JAX needs sometimes to compare `treedef` for equality. Therefore, care must be
|
||||
taken to ensure that the auxiliary data specified in the flattening recipe
|
||||
supports a meaningful equality comparison.
|
||||
|
||||
The whole set of functions for operating on pytrees are in {mod}`jax.tree_util`.
|
270
docs/pytrees.rst
270
docs/pytrees.rst
@ -1,270 +0,0 @@
|
||||
.. _pytrees:
|
||||
|
||||
Pytrees
|
||||
=======
|
||||
|
||||
What is a pytree?
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
In JAX, we use the term *pytree* to refer to a tree-like structure built out of
|
||||
container-like Python objects. Classes are considered container-like if they
|
||||
are in the pytree registry, which by default includes lists, tuples, and dicts.
|
||||
That is:
|
||||
|
||||
1. any object whose type is *not* in the pytree container registry is
|
||||
considered a *leaf* pytree;
|
||||
2. any object whose type is in the pytree container registry, and which
|
||||
contains pytrees, is considered a pytree.
|
||||
|
||||
For each entry in the pytree container registry, a container-like type is
|
||||
registered with a pair of functions which specify how to convert an instance of
|
||||
the container type to a ``(children, metadata)`` pair and how to convert such a
|
||||
pair back to an instance of the container type. Using these functions, JAX can
|
||||
canonicalize any tree of registered container objects into tuples.
|
||||
|
||||
Example pytrees::
|
||||
|
||||
[1, "a", object()] # 3 leaves
|
||||
|
||||
(1, (2, 3), ()) # 3 leaves
|
||||
|
||||
[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
|
||||
|
||||
JAX can be extended to consider other container types as pytrees; see
|
||||
`Extending pytrees`_ below.
|
||||
|
||||
Pytrees and JAX functions
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Many JAX functions, like ``jax.lax.scan``, operate over pytrees of arrays.
|
||||
JAX function transformations can be applied to functions that accept as input
|
||||
and produce as output pytrees of arrays.
|
||||
|
||||
Applying optional parameters to pytrees
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Some JAX function transformations take optional parameters that specify how
|
||||
certain input or output values should be treated (e.g. the ``in_axes`` and
|
||||
``out_axes`` arguments to ``vmap``). These parameters can also be pytrees, and
|
||||
their structure must correspond to the pytree structure of the corresponding
|
||||
arguments. In particular, to be able to "match up" leaves in these parameter
|
||||
pytrees with values in the argument pytrees, the parameter pytrees are often
|
||||
constrained to be tree prefixes of the argument pytrees.
|
||||
|
||||
For example, if we pass the following input to ``vmap`` (note that the input
|
||||
arguments to a function considered a tuple)::
|
||||
|
||||
(a1, {"k1": a2, "k2": a3})
|
||||
|
||||
We can use the following ``in_axes`` pytree to specify that only the ``k2``
|
||||
argument is mapped (``axis=0``) and the rest aren't mapped over
|
||||
(``axis=None``)::
|
||||
|
||||
(None, {"k1": None, "k2": 0})
|
||||
|
||||
The optional parameter pytree structure must match that of the main input
|
||||
pytree. However, the optional parameters can optionally be specified as a
|
||||
"prefix" pytree, meaning that a single leaf value can be applied to an entire
|
||||
sub-pytree. For example, if we have the same ``vmap`` input as above, but wish
|
||||
to only map over the dictionary argument, we can use::
|
||||
|
||||
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
|
||||
|
||||
Or, if we want every argument to be mapped, we can simply write a single leaf
|
||||
value that is applied over the entire argument tuple pytree::
|
||||
|
||||
0
|
||||
|
||||
This happens to be the default ``in_axes`` value for ``vmap``!
|
||||
|
||||
The same logic applies to other optional parameters that refer to specific input
|
||||
or output values of a transformed function, e.g. ``vmap``'s ``out_axes``.
|
||||
|
||||
Viewing the pytree definition of an object
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
To view the pytree definition of an arbitrary ``object`` for debugging purposes, you can use::
|
||||
|
||||
from jax.tree_util import tree_structure
|
||||
print(tree_structure(object))
|
||||
|
||||
Developer information
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
*This is primarily JAX internal documentation, end-users are not supposed to need
|
||||
to understand this to use JAX, except when registering new user-defined
|
||||
container types with JAX. Some of these details may change.*
|
||||
|
||||
Internal pytree handling
|
||||
------------------------
|
||||
|
||||
JAX flattens pytrees into lists of leaves at the ``api.py`` boundary (and also
|
||||
in control flow primitives). This keeps downstream JAX internals simpler:
|
||||
transformations like ``grad``, ``jit``, and ``vmap`` can handle user functions
|
||||
that accept and return the myriad different Python containers, while all the
|
||||
other parts of the system can operate on functions that only take (multiple)
|
||||
array arguments and always return a flat list of arrays.
|
||||
|
||||
When JAX flattens a pytree it will produce a list of leaves and a ``treedef``
|
||||
object that encodes the structure of the original value. The ``treedef`` can
|
||||
then be used to construct a matching structured value after transforming the
|
||||
leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we
|
||||
handle them assuming referential transparency and that they can't contain
|
||||
reference cycles.
|
||||
|
||||
Here is a simple example::
|
||||
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
import jax.numpy as jnp
|
||||
|
||||
# The structured value to be transformed
|
||||
value_structured = [1., (2., 3.)]
|
||||
|
||||
# The leaves in value_flat correspond to the `*` markers in value_tree
|
||||
value_flat, value_tree = tree_flatten(value_structured)
|
||||
print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree))
|
||||
|
||||
# Transform the flat value list using an element-wise numeric transformer
|
||||
transformed_flat = list(map(lambda v: v * 2., value_flat))
|
||||
print("transformed_flat={}".format(transformed_flat))
|
||||
|
||||
# Reconstruct the structured output, using the original
|
||||
transformed_structured = tree_unflatten(value_tree, transformed_flat)
|
||||
print("transformed_structured={}".format(transformed_structured))
|
||||
|
||||
# Output:
|
||||
# value_flat=[1.0, 2.0, 3.0]
|
||||
# value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])
|
||||
# transformed_flat=[2.0, 4.0, 6.0]
|
||||
# transformed_structured=[2.0, (4.0, 6.0)]
|
||||
|
||||
By default, pytree containers can be lists, tuples, dicts, namedtuple, None,
|
||||
OrderedDict. Other types of values, including numeric and ndarray values, are
|
||||
treated as leaves::
|
||||
|
||||
from collections import namedtuple
|
||||
Point = namedtuple('Point', ['x', 'y'])
|
||||
|
||||
example_containers = [
|
||||
(1., [2., 3.]),
|
||||
(1., {'b': 2., 'a': 3.}),
|
||||
1.,
|
||||
None,
|
||||
jnp.zeros(2),
|
||||
Point(1., 2.)
|
||||
]
|
||||
def show_example(structured):
|
||||
flat, tree = tree_flatten(structured)
|
||||
unflattened = tree_unflatten(tree, flat)
|
||||
print("structured={}\n flat={}\n tree={}\n unflattened={}".format(
|
||||
structured, flat, tree, unflattened))
|
||||
|
||||
for structured in example_containers:
|
||||
show_example(structured)
|
||||
|
||||
# Output:
|
||||
# structured=(1.0, [2.0, 3.0])
|
||||
# flat=[1.0, 2.0, 3.0]
|
||||
# tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])])
|
||||
# unflattened=(1.0, [2.0, 3.0])
|
||||
# structured=(1.0, {'b': 2.0, 'a': 3.0})
|
||||
# flat=[1.0, 3.0, 2.0]
|
||||
# tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])])
|
||||
# unflattened=(1.0, {'a': 3.0, 'b': 2.0})
|
||||
# structured=1.0
|
||||
# flat=[1.0]
|
||||
# tree=*
|
||||
# unflattened=1.0
|
||||
# structured=None
|
||||
# flat=[]
|
||||
# tree=PyTreeDef(None, [])
|
||||
# unflattened=None
|
||||
# structured=[0. 0.]
|
||||
# flat=[DeviceArray([0., 0.], dtype=float32)]
|
||||
# tree=*
|
||||
# unflattened=[0. 0.]
|
||||
# structured=Point(x=1.0, y=2.0)
|
||||
# flat=[1.0, 2.0]
|
||||
# tree=PyTreeDef(namedtuple[<class '__main__.Point'>], [*,*])
|
||||
# unflattened=Point(x=1.0, y=2.0)
|
||||
|
||||
Extending pytrees
|
||||
-----------------
|
||||
|
||||
By default, any part of a structured value that is not recognized as an
|
||||
internal pytree node (i.e. container-like) is treated as a leaf::
|
||||
|
||||
class Special(object):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __repr__(self):
|
||||
return "Special(x={}, y={})".format(self.x, self.y)
|
||||
|
||||
|
||||
show_example(Special(1., 2.))
|
||||
|
||||
# Output:
|
||||
# structured=Special(x=1.0, y=2.0)
|
||||
# flat=[Special(x=1.0, y=2.0)]
|
||||
# tree=*
|
||||
# unflattened=Special(x=1.0, y=2.0)
|
||||
|
||||
The set of Python types that are considered internal pytree nodes is extensible,
|
||||
through a global registry of types. Values of registered types are traversed
|
||||
recursively::
|
||||
|
||||
class RegisteredSpecial(Special):
|
||||
def __repr__(self):
|
||||
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
|
||||
|
||||
def special_flatten(v):
|
||||
"""Specifies a flattening recipe.
|
||||
|
||||
Params:
|
||||
v: the value of registered type to flatten.
|
||||
Returns:
|
||||
a pair of an iterable with the children to be flattened recursively,
|
||||
and some opaque auxiliary data to pass back to the unflattening recipe.
|
||||
The auxiliary data is stored in the treedef for use during unflattening.
|
||||
The auxiliary data could be used, e.g., for dictionary keys.
|
||||
"""
|
||||
children = (v.x, v.y)
|
||||
aux_data = None
|
||||
return (children, aux_data)
|
||||
|
||||
def special_unflatten(aux_data, children):
|
||||
"""Specifies an unflattening recipe.
|
||||
|
||||
Params:
|
||||
aux_data: the opaque data that was specified during flattening of the
|
||||
current treedef.
|
||||
children: the unflattened children
|
||||
|
||||
Returns:
|
||||
a re-constructed object of the registered type, using the specified
|
||||
children and auxiliary data.
|
||||
"""
|
||||
return RegisteredSpecial(*children)
|
||||
|
||||
# Global registration
|
||||
register_pytree_node(
|
||||
RegisteredSpecial,
|
||||
special_flatten, # tell JAX what are the children nodes
|
||||
special_unflatten # tell JAX how to pack back into a RegisteredSpecial
|
||||
)
|
||||
|
||||
show_example(RegisteredSpecial(1., 2.))
|
||||
|
||||
# Output:
|
||||
# structured=RegisteredSpecial(x=1.0, y=2.0)
|
||||
# flat=[1.0, 2.0]
|
||||
# tree=PyTreeDef(<class '__main__.RegisteredSpecial'>[None], [*,*])
|
||||
# unflattened=RegisteredSpecial(x=1.0, y=2.0)
|
||||
|
||||
JAX needs sometimes to compare ``treedef`` for equality. Therefore, care must be
|
||||
taken to ensure that the auxiliary data specified in the flattening recipe
|
||||
supports a meaningful equality comparison.
|
||||
|
||||
The whole set of functions for operating on pytrees are in `tree_util module
|
||||
<https://jax.readthedocs.io/en/latest/jax.tree_util.html>`_.
|
Loading…
x
Reference in New Issue
Block a user