JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees.
This section will explain how to use them, provide useful code examples, and point out common "gotchas" and patterns.
A pytree is a container-like structure built out of container-like Python objects — “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that’s not a pytree, such as an array, but a single leaf is also a pytree.
In the context of machine learning (ML), a pytree can contain:
Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree.leaves` to extract the flattened leaves from the trees, as demonstrated here:
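A minimal sketch of this idea is shown next. The contents of `example_trees` are illustrative values chosen for this sketch, not data from a real model.

```python
import jax
import jax.numpy as jnp

# Illustrative pytrees: nested lists, tuples and dicts with arbitrary leaves.
example_trees = [
    [1, "a", object()],
    (1, (2, 3), ()),
    [1, {"k1": 2, "k2": (3, 4)}, 5],
    {"a": 2, "b": (2, 3)},
    jnp.array([1, 2, 3]),
]

# jax.tree.leaves returns the leaves of each tree as a flat list.
for pytree in example_trees:
    leaves = jax.tree.leaves(pytree)
    print(f"{pytree!r:<45} has {len(leaves)} leaves: {leaves}")
```

Note that the empty tuple contributes no leaves, and that the array on the last line is a single leaf rather than three.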
Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX.
Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is *not* in the pytree container registry will be treated as a leaf node in the tree.
The most commonly used pytree function is {func}`jax.tree.map`. It works analogously to Python's native `map`, but transparently operates over entire pytrees.
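As a small sketch, the snippet below doubles every leaf of a nested list. The variable names are chosen for this illustration only.

```python
import jax

# A ragged list of lists; each integer is a leaf.
list_of_lists = [[1, 2, 3], [1, 2], [1, 2, 3, 4]]

# The function is applied to every leaf and the nesting is preserved.
doubled = jax.tree.map(lambda x: x * 2, list_of_lists)
print(doubled)
```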
When using multiple arguments with {func}`jax.tree.map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
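The sketch below illustrates both halves of that rule with two hypothetical parameter dicts: matching structures are combined leaf by leaf, and a mismatched structure is rejected.

```python
import jax

first = {"w": [1, 2], "b": 3}
second = {"w": [10, 20], "b": 30}

# Same keys and same list lengths: leaves are paired up and added.
summed = jax.tree.map(lambda x, y: x + y, first, second)
print(summed)

# The list under "w" has a different length here, so the structures differ.
mismatched = {"w": [10], "b": 30}
try:
    jax.tree.map(lambda x, y: x + y, first, mismatched)
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print(f"structure mismatch: {err}")
```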
This example demonstrates how pytree operations can be useful when training a simple [multi-layer perceptron (MLP)](https://en.wikipedia.org/wiki/Multilayer_perceptron).
Begin by defining the initial model parameters:
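```{code-cell}
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    # Assumed loop body (the original listing is cut off here).
    w = np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in)
    params.append(dict(weights=w, biases=np.ones(n_out)))
  return params
```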
This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree.map`.
Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you _register_ it with JAX. This is also the case even if your container class has trees inside it. For example:
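A short sketch of this behavior follows. `Special` is a hypothetical, unregistered container class used only for illustration.

```python
import jax

class Special:
    """A hypothetical container that has NOT been registered with JAX."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

# Each Special instance is treated as one opaque leaf, so JAX does not
# look inside it for x and y.
leaves = jax.tree.leaves([Special(0, 1), Special(2, 4)])
print(len(leaves))
```

Because the instances are leaves, a call such as `jax.tree.map(lambda v: v + 1, ...)` would try to add `1` to a `Special` object itself rather than to its fields.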
As a solution, JAX lets you extend the set of types considered internal pytree nodes through a global registry of types. The values of registered types are then traversed recursively.
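The following is a minimal sketch of registering a class, assuming a hypothetical `LabeledPair` container whose `label` should be carried as auxiliary data rather than as a leaf. The class and helper names are made up for this illustration.

```python
import jax
from jax.tree_util import register_pytree_node

class LabeledPair:
    """Hypothetical container: `label` is metadata, `x` and `y` are children."""

    def __init__(self, label, x, y):
        self.label, self.x, self.y = label, x, y

def flatten_pair(pair):
    # Children are traversed by JAX; aux_data is stored alongside the structure.
    children = (pair.x, pair.y)
    aux_data = pair.label
    return children, aux_data

def unflatten_pair(aux_data, children):
    return LabeledPair(aux_data, *children)

register_pytree_node(LabeledPair, flatten_pair, unflatten_pair)

pairs = [LabeledPair("p", 1, 2), LabeledPair("q", 3, 4)]
leaves = jax.tree.leaves(pairs)
bumped = jax.tree.map(lambda v: v + 1, pairs)
print(leaves, bumped[1].label, bumped[1].y)
```

The auxiliary data should be hashable and comparable, since it becomes part of the tree structure.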
Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care.
For instance, a Python `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type:
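```{code-cell}
from typing import NamedTuple, Any

import jax

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# no registration is needed (the instances below are illustrative).
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
```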
Notice that the `name` field now appears as a leaf, because all tuple elements are children. This is what happens when you don't have to register the class the hard way.
Unlike `NamedTuple` subclasses, classes decorated with `@dataclass` are not automatically pytrees. However, they can be registered as pytrees using the {func}`jax.tree_util.register_dataclass` decorator:
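A sketch of such a registration is shown here. It passes `data_fields` and `meta_fields` explicitly through `functools.partial`; the instance values are illustrative.

```python
import functools
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp

@functools.partial(
    jax.tree_util.register_dataclass,
    data_fields=["a", "b", "c"],  # traversed as children
    meta_fields=["name"],         # kept as auxiliary data
)
@dataclass
class MyDataclassContainer:
    name: str
    a: Any
    b: Any
    c: Any

# Illustrative instances; only a, b and c contribute leaves.
leaves = jax.tree.leaves([
    MyDataclassContainer("apple", 5.3, 1.2, jnp.zeros(4)),
    MyDataclassContainer("banana", 3, -1.0, 0.0),
])
print(len(leaves))
```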
Notice that the `name` field does not appear as a leaf. This is because we included it in the `meta_fields` argument to {func}`jax.tree_util.register_dataclass`, indicating that it should be treated as metadata/auxiliary data, just like `aux_data` in `RegisteredSpecial` above. Now instances of `MyDataclassContainer` can be passed into JIT-ed functions, and `name` will be treated as static (see {ref}`jit-marking-arguments-as-static` for more information on static args):
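As a sketch of the JIT case, the snippet below repeats the registration so that it runs on its own. The function name `total` is hypothetical.

```python
import functools
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=["a", "b", "c"], meta_fields=["name"])
@dataclass
class MyDataclassContainer:
    name: str
    a: Any
    b: Any
    c: Any

@jax.jit
def total(x: MyDataclassContainer):
    # x.name is an ordinary Python string during tracing (static metadata),
    # while x.a, x.b and x.c are traced values.
    print(f"tracing with name={x.name!r}")
    return x.a + x.b + x.c

result = total(MyDataclassContainer("apple", 5.3, 1.2, jnp.zeros(4)))
print(result)
```

Since `name` is static, calling `total` with a different `name` would be expected to trigger a new trace.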
Contrast this with `MyOtherContainer`, the `NamedTuple` subclass. Since the `name` field is a pytree leaf, JIT expects it to be convertible to {class}`jax.Array`, and the following raises an error:
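The failure can be sketched as follows. The class is redefined so the snippet is self-contained, and the error is caught only so that the example runs to completion.

```python
from typing import Any, NamedTuple

import jax

class MyOtherContainer(NamedTuple):
    name: str
    a: Any
    b: Any
    c: Any

@jax.jit
def total(x: MyOtherContainer):
    return x.a + x.b + x.c

try:
    # The string in `name` is a leaf, and JIT cannot treat it as an array.
    total(MyOtherContainer("apple", 1.0, 2.0, 3.0))
    raised = False
except TypeError as err:
    raised = True
    print(f"jit rejected the string leaf: {type(err).__name__}")
```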
Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays.
Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the `in_axes` and `out_axes` arguments to {func}`jax.vmap`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.
For example, when one of the inputs is stored under the key `k2`, you can pass an `in_axes` pytree that maps only the `k2` argument (`axis=0`) and leaves the rest unmapped (`axis=None`):
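A minimal sketch, assuming a hypothetical function `f` that takes a scalar and a dict with keys `k1` and `k2` (the function, the `k1` key, and the values are made up for this illustration):

```python
import jax
import jax.numpy as jnp

def f(a1, d):
    return a1 + d["k1"] + d["k2"]

a1 = jnp.asarray(1.0)
a2 = jnp.asarray(10.0)
a3 = jnp.arange(3.0)  # the only batched input

# in_axes mirrors the argument structure: one entry per leaf.
out = jax.vmap(f, in_axes=(None, {"k1": None, "k2": 0}))(a1, {"k1": a2, "k2": a3})
print(out)
```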
The structure of the optional parameter pytree must match that of the main input pytree. However, the parameter can also be given as a “prefix” pytree, meaning that a single leaf value is applied to an entire sub-pytree.
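The prefix form can be sketched like this, again with a hypothetical function `f` and illustrative inputs. A single `0` stands in for the whole dict, which should give the same result as spelling out every leaf.

```python
import jax
import jax.numpy as jnp

def f(a1, d):
    return a1 + d["k1"] + d["k2"]

a1 = jnp.asarray(1.0)
batched = {"k1": jnp.arange(3.0), "k2": 10.0 * jnp.arange(3.0)}

# Prefix form: the single 0 applies to every leaf of the dict.
prefix_out = jax.vmap(f, in_axes=(None, 0))(a1, batched)

# Fully spelled-out form with one entry per leaf.
full_out = jax.vmap(f, in_axes=(None, {"k1": 0, "k2": 0}))(a1, batched)
print(prefix_out, full_out)
```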
The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as `out_axes` in {func}`jax.vmap`.
In a pytree, each leaf has a _key path_. A key path for a leaf is a `list` of _keys_, where the length of the list is equal to the depth of the leaf in the pytree. Each _key_ is a [hashable object](https://docs.python.org/3/glossary.html#term-hashable) that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s.
For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.
JAX has the following `jax.tree_util.*` methods for working with key paths:
You are free to define your own key types for your custom nodes. They will work with {func}`jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression.
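To make key paths concrete, here is a short sketch that uses {func}`jax.tree_util.tree_flatten_with_path` together with {func}`jax.tree_util.keystr` on a small tree of built-in containers. The tree itself is illustrative.

```python
import jax
from jax.tree_util import keystr, tree_flatten_with_path

tree = [1, {"k1": 2, "k2": (3, 4)}]

# Each entry pairs a key path (a sequence of keys) with its leaf.
paths_and_leaves, treedef = tree_flatten_with_path(tree)
for path, leaf in paths_and_leaves:
    print(f"tree{keystr(path)} = {leaf}")

rendered = [keystr(path) for path, _ in paths_and_leaves]
```

The length of each path matches the depth of its leaf: one key for `tree[0]`, three for the elements of the inner tuple.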
What happened here is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.
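The effect can be reproduced with the sketch below, which uses an illustrative `a_tree`. It also shows one possible way to sidestep the problem by not building the intermediate tree of shapes at all.

```python
import jax
import jax.numpy as jnp

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Each shape tuple becomes a pytree node whose leaves are plain ints.
shapes = jax.tree.map(lambda x: x.shape, a_tree)

# jnp.ones is therefore called on 2, 3, 3 and 4 separately.
surprising = jax.tree.map(jnp.ones, shapes)
print([leaf.shape for leaf in jax.tree.leaves(surprising)])

# Possible workaround: compute the result in a single map over the arrays.
direct = jax.tree.map(lambda x: jnp.ones(x.shape), a_tree)
print([leaf.shape for leaf in direct])
```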
The solution will depend on the specifics, but there are two broadly applicable options:
### Custom pytrees and initialization with unexpected values
Another common gotcha with user-defined pytree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example:
```{code-cell}
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.
```
- In the first case with `jax.vmap(...)(tree)`, JAX’s internals use arrays of `object()` values to infer the structure of the tree.
- In the second case with `jax.jacobian(...)(tree)`, the Jacobian of a function mapping a tree to a tree is defined as a tree of trees.
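The first case can be imitated without relying on any particular transformation. The sketch below assumes a `MyTree` class that converts its input with `jnp.asarray` in `__init__`, and rebuilds the tree from a placeholder `object()` leaf by hand. This is only an approximation of what happens internally.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class MyTree:
    def __init__(self, a):
        # Conversion at construction time: fails for non-array placeholders.
        self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
                     lambda _, children: MyTree(*children))

tree = MyTree(jnp.arange(5.0))
leaves, treedef = jax.tree.flatten(tree)

# Rebuilding with real array leaves works as expected.
rebuilt = jax.tree.unflatten(treedef, leaves)

# Rebuilding with a placeholder leaf calls MyTree(object()), which fails.
try:
    jax.tree.unflatten(treedef, [object()])
    failed = False
except Exception as err:
    failed = True
    print(f"{type(err).__name__} raised inside MyTree.__init__")
```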
**Potential solution 1:**
- The `__init__` and `__new__` methods of custom pytree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example:
```{code-cell}
import jax.numpy as jnp

class MyTree:
  def __init__(self, a):
    # Skip conversion for the placeholder values JAX may pass in.
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a
```
**Potential solution 2:**
- Structure your custom `tree_unflatten` function so that it avoids calling `__init__`. If you choose this route, make sure that your `tree_unflatten` function stays in sync with `__init__` if and when the code is updated. Example:
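```{code-cell}
def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  (a,) = children  # MyTree has a single child.
  # Bypass __init__ so that no validation runs on placeholder values.
  obj = object.__new__(MyTree)
  obj.a = a
  return obj
```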
(pytrees-common-pytree-patterns)=
## Common pytree patterns
This section covers some of the most common patterns with JAX pytrees.
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
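With {func}`jax.tree.map`, the basic approach can be sketched as below. The helper name `tree_transpose` and the `episode_steps` data are hypothetical.

```python
import jax

def tree_transpose(list_of_trees):
    """Turn a list of identically structured trees into a tree of lists."""
    # Each call receives the matching leaf from every tree in the list.
    return jax.tree.map(lambda *leaves: list(leaves), *list_of_trees)

episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
transposed = tree_transpose(episode_steps)
print(transposed)
```

This relies on every tree in the list having exactly the same structure.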
**Option 2:** For more complex transposes, use {func}`jax.tree.transpose`, which is more verbose, but allows you to specify the structure of the inner and outer pytree for more flexibility. For example:
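A sketch of this call, using illustrative `episode_steps` data: the outer structure is a list with one entry per step, and the inner structure is that of a single step.

```python
import jax

episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]

transposed = jax.tree.transpose(
    # Outer structure: a list with one placeholder leaf per step.
    outer_treedef=jax.tree.structure([0 for _ in episode_steps]),
    # Inner structure: the dict layout shared by every step.
    inner_treedef=jax.tree.structure(episode_steps[0]),
    pytree_to_transpose=episode_steps,
)
print(transposed)
```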