mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Copybara import of the project:
-- 371c5a45ea08c8e92136761149d0016077a58652 by Jake VanderPlas <jakevdp@google.com>: pytree doc: add discussion of children vs aux_data COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15007 from jakevdp:pytree-doc 371c5a45ea08c8e92136761149d0016077a58652 PiperOrigin-RevId: 517149897
This commit is contained in:
parent
e275a9aa5c
commit
56267f08dd
@ -280,9 +280,12 @@ class RegisteredSpecial2(Special):
|
||||
show_example(RegisteredSpecial2(1., 2.))
|
||||
```
|
||||
|
||||
JAX sometimes needs to compare `treedef` for equality. Therefore, care must be
|
||||
taken to ensure that the auxiliary data specified in the flattening recipe
|
||||
supports a meaningful equality comparison.
|
||||
When defining an unflattening functions, in general `children` should contain all the
|
||||
dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while
|
||||
`aux_data` should contain all the static elements that will be rolled into the `treedef`
|
||||
structure. JAX sometimes needs to compare `treedef` for equality, or compute its hash
|
||||
for use in the JIT cache, and so care must be taken to ensure that the auxiliary data
|
||||
specified in the flattening recipe supports meaningful hashing and equality comparisons.
|
||||
|
||||
The whole set of functions for operating on pytrees are in {mod}`jax.tree_util`.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user