mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: fix issue in pytrees.md
This commit is contained in:
parent
3c003a68fc
commit
89dc691da7
@ -264,12 +264,13 @@ on your class and decorate it with {func}`~jax.tree_util.register_pytree_node_cl
|
||||
```{code-cell}
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
@register_pytree_node_class
|
||||
class RegisteredSpecial2(Special):
|
||||
def __repr__(self):
|
||||
return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)
|
||||
|
||||
def tree_flatten(self):
|
||||
children = (v.x, v.y)
|
||||
children = (self.x, self.y)
|
||||
aux_data = None
|
||||
return (children, aux_data)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user