mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: add register_pytree_node_class example
This commit is contained in:
parent
b935d33ccb
commit
011c5a203e
@ -142,7 +142,7 @@ _jnp.arange(10)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
import jax.numpy as jnp
|
||||
|
||||
# The structured value to be transformed
|
||||
@ -208,10 +208,13 @@ show_example(Special(1., 2.))
|
||||
```
|
||||
|
||||
The set of Python types that are considered internal pytree nodes is extensible,
|
||||
through a global registry of types. Values of registered types are traversed
|
||||
recursively:
|
||||
through a global registry of types, and values of registered types are traversed
|
||||
recursively. To register a new type, you can use
|
||||
{func}`~jax.tree_util.register_pytree_node`:
|
||||
|
||||
```{code-cell}
|
||||
from jax.tree_util import register_pytree_node
|
||||
|
||||
class RegisteredSpecial(Special):
|
||||
def __repr__(self):
|
||||
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
|
||||
@ -255,6 +258,28 @@ register_pytree_node(
|
||||
show_example(RegisteredSpecial(1., 2.))
|
||||
```
|
||||
|
||||
Alternatively, you can define appropriate `tree_flatten` and `tree_unflatten` methods
|
||||
on your class and decorate it with {func}`~jax.tree_util.register_pytree_node_class`:
|
||||
|
||||
```{code-cell}
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
class RegisteredSpecial2(Special):
|
||||
def __repr__(self):
|
||||
return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)
|
||||
|
||||
def tree_flatten(self):
|
||||
children = (v.x, v.y)
|
||||
aux_data = None
|
||||
return (children, aux_data)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
show_example(RegisteredSpecial2(1., 2.))
|
||||
```
|
||||
|
||||
JAX needs sometimes to compare `treedef` for equality. Therefore, care must be
|
||||
taken to ensure that the auxiliary data specified in the flattening recipe
|
||||
supports a meaningful equality comparison.
|
||||
|
Loading…
x
Reference in New Issue
Block a user