mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: add alternative for pytree initialization
This commit is contained in:
parent
e276859d11
commit
47ec553c40
@ -320,3 +320,14 @@ class MyTree:
|
||||
a = jnp.asarray(a)
|
||||
self.a = a
|
||||
```
|
||||
Another possibility is to structure your `tree_unflatten` function so that it avoids
|
||||
calling `__init__`; for example:
|
||||
```{code-cell}
|
||||
def tree_unflatten(aux_data, children):
|
||||
del aux_data # unused in this class
|
||||
obj = object.__new__(MyTree)
|
||||
obj.a = a
|
||||
return obj
|
||||
```
|
||||
If you go this route, make sure that your `tree_unflatten` function stays in-sync with
|
||||
`__init__` if and when the code is updated.
|
Loading…
x
Reference in New Issue
Block a user