Copybara import of the project:

--
371c5a45ea08c8e92136761149d0016077a58652 by Jake VanderPlas <jakevdp@google.com>:

pytree doc: add discussion of children vs aux_data

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15007 from jakevdp:pytree-doc 371c5a45ea08c8e92136761149d0016077a58652
PiperOrigin-RevId: 517149897
This commit is contained in:
Jake Vanderplas 2023-03-16 09:53:44 -07:00 committed by jax authors
parent e275a9aa5c
commit 56267f08dd

View File

@ -280,9 +280,12 @@ class RegisteredSpecial2(Special):
show_example(RegisteredSpecial2(1., 2.))
```
JAX sometimes needs to compare `treedef` for equality. Therefore, care must be
taken to ensure that the auxiliary data specified in the flattening recipe
supports a meaningful equality comparison.
When defining an unflattening functions, in general `children` should contain all the
dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while
`aux_data` should contain all the static elements that will be rolled into the `treedef`
structure. JAX sometimes needs to compare `treedef` for equality, or compute its hash
for use in the JIT cache, and so care must be taken to ensure that the auxiliary data
specified in the flattening recipe supports meaningful hashing and equality comparisons.
The whole set of functions for operating on pytrees are in {mod}`jax.tree_util`.