DOC: add register_pytree_node_class example

This commit is contained in:
Jake VanderPlas 2021-04-08 08:55:18 -07:00
parent b935d33ccb
commit 011c5a203e

View File

@ -142,7 +142,7 @@ _jnp.arange(10)
```
```{code-cell}
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp
# The structured value to be transformed
@ -208,10 +208,13 @@ show_example(Special(1., 2.))
```
The set of Python types that are considered internal pytree nodes is extensible,
through a global registry of types. Values of registered types are traversed
recursively:
through a global registry of types, and values of registered types are traversed
recursively. To register a new type, you can use
{func}`~jax.tree_util.register_pytree_node`:
```{code-cell}
from jax.tree_util import register_pytree_node
class RegisteredSpecial(Special):
def __repr__(self):
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
@ -255,6 +258,28 @@ register_pytree_node(
show_example(RegisteredSpecial(1., 2.))
```
Alternatively, you can define appropriate `tree_flatten` and `tree_unflatten` methods
on your class and decorate it with {func}`~jax.tree_util.register_pytree_node_class`:
```{code-cell}
from jax.tree_util import register_pytree_node_class
class RegisteredSpecial2(Special):
def __repr__(self):
return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)
def tree_flatten(self):
children = (v.x, v.y)
aux_data = None
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
show_example(RegisteredSpecial2(1., 2.))
```
JAX needs sometimes to compare `treedef` for equality. Therefore, care must be
taken to ensure that the auxiliary data specified in the flattening recipe
supports a meaningful equality comparison.