DOC: fix issue in pytrees.md

This commit is contained in:
Jake VanderPlas 2021-04-15 09:57:53 -07:00
parent 3c003a68fc
commit 89dc691da7

View File

@ -264,12 +264,13 @@ on your class and decorate it with {func}`~jax.tree_util.register_pytree_node_cl
```{code-cell}
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class RegisteredSpecial2(Special):
def __repr__(self):
return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)
def tree_flatten(self):
children = (v.x, v.y)
children = (self.x, self.y)
aux_data = None
return (children, aux_data)