+++ {"id": "-h05_PNNhZ-D"}
Working with Pytrees
Author: Vladimir Mikulik
Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as pytrees, but you can sometimes see them called nests, or just trees.
JAX has built-in support for such objects, both in its library functions and through the functions in `jax.tree_util` (with the most common ones also available as `jax.tree_*`). This section explains how to use them, gives some useful snippets, and points out common gotchas.
+++ {"id": "9UjxVY9ulSCn"}
What is a pytree?
As defined in the JAX pytree docs:
a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.
Some example pytrees:
---
id: Wh6BApZ9lrR1
outputId: df1fa4cd-88a6-4d71-a376-b2ddf91568dd
---
import jax
import jax.numpy as jnp
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
  leaves = jax.tree_util.tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
+++ {"id": "_tWkkGNwW8vf"}
We've also introduced our first tree utility, `jax.tree_util.tree_leaves`, which extracts the flattened leaves from a tree.
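Flattening also works in the other direction. The following is a minimal sketch, assuming an illustrative `tree` value, of how `jax.tree_util.tree_flatten` returns the leaves together with a structure description (a treedef), and how `jax.tree_util.tree_unflatten` rebuilds a tree of the same shape from a new list of leaves:

```python
import jax

# Illustrative tree, not taken from a real model.
tree = [1, {'k1': 2, 'k2': (3, 4)}, 5]

# tree_flatten returns the leaves plus a treedef describing the structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)   # [1, 2, 3, 4, 5]

# Transform the flat list, then rebuild a tree with the original structure.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
print(doubled)  # [2, {'k1': 4, 'k2': (6, 8)}, 10]
```

This flatten, transform, unflatten round trip is essentially what a tree map does for a single input tree.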
+++ {"id": "RcsmneIGlltm"}
Why pytrees?
In machine learning, some places where you commonly find pytrees are:
- Model parameters
- Dataset entries
- RL agent observations
They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts).
+++ {"id": "sMrSGSIJn9MD"}
Common pytree functions
Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees:
:id: wZRcuQu4n7o5
:outputId: 3528bc9f-54ed-49c8-b79a-1cbea176c0f3
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]
jax.tree_map(lambda x: x*2, list_of_lists)
+++ {"id": "xu8X3fk4orC9"}
`jax.tree_map` also works with multiple arguments:
:id: KVpB4r1OkeUK
:outputId: 33f88a7e-aac7-48cd-d207-2c531cd37733
another_list_of_lists = list_of_lists
jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
+++ {"id": "dkRKy3LvowAb"}
When using multiple arguments with `jax.tree_map`, the structures of the inputs must match exactly. That is, lists must have the same number of elements, dicts must have the same keys, etc.
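To make the matching requirement concrete, here is a small sketch with two hypothetical lists of different lengths. It uses the long-form `jax.tree_util.tree_map`, and it assumes that a structure mismatch surfaces as a `ValueError`, which is the behaviour we are aware of at the time of writing:

```python
import jax

# Hypothetical inputs whose structures differ (2 leaves vs. 3 leaves).
shorter = [1, 2]
longer = [1, 2, 3]

# Comparing treedefs is a cheap way to check compatibility up front.
same_structure = (jax.tree_util.tree_structure(shorter)
                  == jax.tree_util.tree_structure(longer))
print(same_structure)  # False

# Mapping over mismatched trees fails instead of silently truncating.
try:
  jax.tree_util.tree_map(lambda x, y: x + y, shorter, longer)
  mismatch_detected = False
except ValueError as e:
  mismatch_detected = True
  print(f'ValueError: {e}')
```

Note the difference from Python's built-in `zip`, which would silently stop at the shorter input.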
+++ {"id": "Lla4hDW6sgMZ"}
Example: ML model parameters
A simple example of training an MLP shows some ways in which pytree operations come in handy:
:id: j2ZUzWx8tKB2
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])
+++ {"id": "kUFwJOspuGvU"}
We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:
:id: ErWsXuxXse-z
:outputId: d3e549ab-40ef-470e-e460-1b5939d9696f
jax.tree_map(lambda x: x.shape, params)
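A related check is counting the total number of parameters. The sketch below assumes a small hypothetical `toy_params` tree with the same list-of-dicts layout as `params`, and sums the sizes of all leaf arrays:

```python
import jax
import numpy as np

# Hypothetical stand-in for `params`: two layers mapping 1 -> 4 -> 1 units.
toy_params = [
    dict(weights=np.zeros((1, 4)), biases=np.ones((4,))),
    dict(weights=np.zeros((4, 1)), biases=np.ones((1,))),
]

# Each NumPy array is a leaf, so we can sum the element counts over all leaves.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(toy_params))
print(n_params)  # 4 + 4 + 4 + 1 = 13
```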
+++ {"id": "zQtRKaj4ua6-"}
Now, let's train our MLP:
:id: iL4GvW9OuZ-X
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of the many JAX functions that has
  # built-in support for pytrees.

  # This is handy, because we can apply the SGD update using tree utils:
  return jax.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )
:id: B3HniT9-xohz
:outputId: d77e9811-373e-45d6-ccbe-edb6f43120d7
import matplotlib.pyplot as plt
xs = np.random.normal(size=(128, 1))
ys = xs ** 2
for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend();
+++ {"id": "lNAvmpzdoE9l"}
Key paths
In a pytree, each leaf has a key path. A key path for a leaf is a `list` of keys, where the length of the list is equal to the depth of the leaf in the pytree. Each key is a hashable object that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s.
For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.
The APIs for working with key paths are:
- `jax.tree_util.tree_flatten_with_path`: Works similarly to `jax.tree_util.tree_flatten`, but returns key paths.
- `jax.tree_util.tree_map_with_path`: Works similarly to `jax.tree_util.tree_map`, but the function also takes key paths as arguments.
- `jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression.
One use case is to print debugging information related to a certain leaf value:
:id: G6E2YzhvoE9l
:outputId: 5aec83c8-e15e-48eb-b2c3-6fa0164344b5
import collections
# Note the trailing comma: ('name',) is a one-element tuple of field names.
ATuple = collections.namedtuple("ATuple", ('name',))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
+++ {"id": "zrKqmANgoE9l"}
To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:
- `SequenceKey(idx: int)`: for lists and tuples.
- `DictKey(key: Hashable)`: for dictionaries.
- `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section).
You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression.
:id: ohDq0kGuoE9l
:outputId: 9b8ff3ec-3461-482e-ff27-30dc2a7e68c9
for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
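Key paths are also useful for treating leaves differently depending on where they sit in the tree. As a hedged sketch, the example below uses `jax.tree_util.tree_map_with_path` to zero out only the leaves stored under a `'biases'` dict key. The `toy_params` tree and the `zero_biases` helper are hypothetical names introduced for illustration:

```python
import jax
import numpy as np

# Hypothetical parameter tree with the list-of-dicts layout used earlier.
toy_params = [
    dict(weights=np.ones((2, 3)), biases=np.ones((3,))),
    dict(weights=np.ones((3, 1)), biases=np.ones((1,))),
]

def zero_biases(key_path, leaf):
  # The last key in the path identifies the leaf within its parent dict.
  last_key = key_path[-1]
  if isinstance(last_key, jax.tree_util.DictKey) and last_key.key == 'biases':
    return np.zeros_like(leaf)
  return leaf

# The mapped function receives the key path first, then the leaf value.
result = jax.tree_util.tree_map_with_path(zero_biases, toy_params)
print(result[0]['biases'])  # [0. 0. 0.]
```

The same pattern can be used, for instance, to apply an update rule to weights but not to biases.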
+++ {"id": "sBxOB21YNEDA"}
Custom pytree nodes
So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:
:id: CK8LN2PRFnQf
class MyContainer:
  """A named container."""

  def __init__(self, name: str, a: int, b: int, c: int):
    self.name = name
    self.a = a
    self.b = b
    self.c = c
:id: OPGe2R7ZOXCT
:outputId: 40db1f41-9df8-4dea-972a-6a7bc44a49c6
jax.tree_util.tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
+++ {"id": "vk4vucGXPADj"}
Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:
:id: vIr9_JOIOku7
:outputId: dadc9c15-4a10-4fac-e70d-f23e7085cf74
try:
  jax.tree_map(lambda x: x + 1, [
      MyContainer('Alice', 1, 2, 3),
      MyContainer('Bob', 4, 5, 6)
  ])
except TypeError as e:
  print(f'TypeError: {e}')
+++ {"id": "nAZ4FR2lPN51", "tags": ["raises-exception"]}
To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:
:id: 2RR5cDFvoE9m
:outputId: 94745373-abe4-4bca-967c-4133e8027c30
from typing import Iterable

def flatten_MyContainer(container) -> tuple[Iterable[int], str]:
  """Returns an iterable over container contents, and aux data."""
  flat_contents = [container.a, container.b, container.c]

  # we don't want the name to appear as a child, so it is auxiliary data.
  # auxiliary data is usually a description of the structure of a node,
  # e.g., the keys of a dict -- anything that isn't a node's children.
  aux_data = container.name
  return flat_contents, aux_data

def unflatten_MyContainer(
    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
  """Converts aux data and the flat contents into a MyContainer."""
  return MyContainer(aux_data, *flat_contents)

jax.tree_util.register_pytree_node(
    MyContainer, flatten_MyContainer, unflatten_MyContainer)

jax.tree_util.tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
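Once a container is registered, tree maps rebuild it through the unflatten function, so auxiliary data is carried over unchanged. The following is a minimal sketch using a hypothetical `NamedPair` class (a separate class, so it does not clash with the registration of `MyContainer` above):

```python
import jax

class NamedPair:
  """Hypothetical container with a static name and two numeric children."""

  def __init__(self, name, a, b):
    self.name = name
    self.a = a
    self.b = b

jax.tree_util.register_pytree_node(
    NamedPair,
    # flatten: returns (children, aux_data)
    lambda c: ((c.a, c.b), c.name),
    # unflatten: receives (aux_data, children)
    lambda name, children: NamedPair(name, *children),
)

# Only the children are mapped; the name travels along as aux data.
incremented = jax.tree_util.tree_map(lambda x: x + 1, NamedPair('Alice', 1, 2))
print(incremented.name, incremented.a, incremented.b)  # Alice 2 3
```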
+++ {"id": "JXaEe76ZoE9m"}
Alternatively, using the key path API mentioned above, you can register this container with its keys in mind by defining what the keys should look like for each flattened-out value.
:id: D_juQx-2OybX
:outputId: ee2cf4ad-ec21-4636-c9c5-2c64b81429bb
class MyKeyPathContainer(MyContainer):
  pass

def flatten_with_keys_MyKeyPathContainer(
    container) -> tuple[Iterable[tuple[jax.tree_util.GetAttrKey, int]], str]:
  """Returns an iterable over (key, value) pairs, and aux data."""

  # GetAttrKey is a common way to express an attribute key. Users are free
  # to pick any other expression that fits their use cases the best.
  flat_contents = [(jax.tree_util.GetAttrKey('a'), container.a),
                   (jax.tree_util.GetAttrKey('b'), container.b),
                   (jax.tree_util.GetAttrKey('c'), container.c)]

  # we don't want the name to appear as a child, so it is auxiliary data.
  # auxiliary data is usually a description of the structure of a node,
  # e.g., the keys of a dict -- anything that isn't a node's children.
  aux_data = container.name
  return flat_contents, aux_data

def unflatten_MyKeyPathContainer(
    aux_data: str, flat_contents: Iterable[int]) -> MyKeyPathContainer:
  """Converts aux data and the flat contents into a MyKeyPathContainer."""
  return MyKeyPathContainer(aux_data, *flat_contents)

jax.tree_util.register_pytree_with_keys(
    MyKeyPathContainer, flatten_with_keys_MyKeyPathContainer, unflatten_MyKeyPathContainer)

jax.tree_util.tree_leaves([
    MyKeyPathContainer('Alice', 1, 2, 3),
    MyKeyPathContainer('Bob', 4, 5, 6)
])
+++ {"id": "HPX23W4zoE9m"}
`register_pytree_with_keys` is an extended API of `register_pytree_node`, and containers registered in either way can freely use all the `tree_util` utilities without error.
When a container registered with `register_pytree_node` uses `.*_with_path` APIs, the keys being returned will be a series of "flat index" fallbacks:
:id: E1BwD2aZoE9m
:outputId: 4fe12b06-aef4-426a-a732-891affa63842
flattened, _ = jax.tree_util.tree_flatten_with_path(MyContainer('Alice', 1, 2, 3))
for key_path, value in flattened:
  print(f'MyContainer container{jax.tree_util.keystr(key_path)}: {value}')

flattened, _ = jax.tree_util.tree_flatten_with_path(MyKeyPathContainer('Alice', 1, 2, 3))
for key_path, value in flattened:
  print(f'MyKeyPathContainer container{jax.tree_util.keystr(key_path)}: {value}')
+++ {"id": "JgnAp7fFShEB"}
Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance, a `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type:
:id: 8DNoLABtO0fr
:outputId: 9a448508-43eb-4450-bfaf-eeeb59a9e349
from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box:
jax.tree_util.tree_leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
+++ {"id": "TVdtzJDVTZb6"}
Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way.
+++ {"id": "wDbVszv-oE9n"}
One shortcut is to use `jax.tree_util.register_static` to register a type as being a node without children:
---
id: Rclc079ioE9n
outputId: 6b6a4402-8fc1-409c-b6da-88568a612e1b
---
from typing import NamedTuple, Any

@jax.tree_util.register_static
class StaticStr(str):
  pass

class YetAnotherContainer(NamedTuple):
  name: StaticStr
  a: Any
  b: Any
  c: Any

# StaticStr is registered as a node without children, so the
# names do not show up among the leaves:
jax.tree_util.tree_leaves([
    YetAnotherContainer(StaticStr('Alice'), 1, 2, 3),
    YetAnotherContainer(StaticStr('Bob'), 4, 5, 6)
])
+++ {"id": "kNsTszcEEHD0"}
Common pytree gotchas and patterns
+++ {"id": "0ki-JDENzyL7"}
Gotchas
Mistaking nodes for leaves
A common problem to look out for is accidentally introducing tree nodes instead of leaves:
:id: N-th4jOAGJlM
:outputId: 23eed14d-d383-4d88-d6f9-02bac06020df
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another tree with ones instead of zeros
shapes = jax.tree_map(lambda x: x.shape, a_tree)
jax.tree_map(jnp.ones, shapes)
+++ {"id": "q8d4y-hfHTWh"}
What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.
The solution will depend on the specifics, but there are two broadly applicable options:
- rewrite the code to avoid the intermediate `tree_map`.
- convert the tuple into an `np.array` or `jnp.array`, which makes the entire sequence a leaf.
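Both options can be sketched as follows. This is one possible way to write them, not the only one; in the second option the shape is converted back to a tuple of Python ints before calling `jnp.ones`, which is a conservative choice on our part:

```python
import jax
import jax.numpy as jnp
import numpy as np

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Option 1: skip the intermediate tree of shapes entirely.
ones_direct = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), a_tree)

# Option 2: store each shape as a NumPy array, which is a single leaf.
shapes = jax.tree_util.tree_map(lambda x: np.array(x.shape), a_tree)
ones_via_shapes = jax.tree_util.tree_map(
    lambda s: jnp.ones(tuple(int(d) for d in s)), shapes)

print([x.shape for x in ones_direct])      # [(2, 3), (3, 4)]
print([x.shape for x in ones_via_shapes])  # [(2, 3), (3, 4)]
```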
+++ {"id": "4OKlbFlEIda-"}
Handling of None
`jax.tree_util` treats `None` as a node without children, not as a leaf:
:id: gIwlwo2MJcEC
:outputId: 1e59f323-a7b7-42be-8603-afa4693c00cc
jax.tree_util.tree_leaves([None, None, None])
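If you do want `None` to be handed to your function, the `is_leaf` argument accepted by the tree utilities can override this default. A minimal sketch, with `is_none` being a helper name made up for this example:

```python
import jax

tree = [1, None, 3]

# By default None contributes no leaves and is passed through by tree_map.
print(jax.tree_util.tree_leaves(tree))                 # [1, 3]
print(jax.tree_util.tree_map(lambda x: x * 10, tree))  # [10, None, 30]

# With is_leaf, None is treated like any other leaf.
is_none = lambda x: x is None
leaves_with_none = jax.tree_util.tree_leaves(tree, is_leaf=is_none)
print(leaves_with_none)                                # [1, None, 3]
```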
+++ {"id": "pwNz-rp1JvW4"}
Patterns
Transposing trees
If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_map`:
:id: UExN7-G7qU-F
:outputId: fd049086-ef37-44db-8e2c-9f1bd9fad950
def tree_transpose(list_of_trees):
  """Convert a list of trees of identical structure into a single tree of lists."""
  return jax.tree_map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major:
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
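A closely related pattern is to stack the corresponding leaves into arrays rather than collecting them into lists, which turns a list of examples into a single batched tree. The sketch below assumes hypothetical `steps` data whose leaves are arrays of matching shapes:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-step records with array leaves of matching shapes.
steps = [dict(t=jnp.array(1), obs=jnp.array([0.0, 1.0])),
         dict(t=jnp.array(2), obs=jnp.array([2.0, 3.0]))]

# Stack corresponding leaves along a new leading (batch) axis.
batched = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *steps)
print(batched['t'].shape, batched['obs'].shape)  # (2,) (2, 2)
```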
+++ {"id": "Ao6R2ffm2CF4"}
For more complicated transposes, JAX provides `jax.tree_util.tree_transpose`, which is more verbose, but allows you to specify the structure of the inner and outer pytrees for more flexibility:
:id: bZvVwxshz1D3
:outputId: a0314dc8-4267-41e6-a763-931d40433c26
jax.tree_util.tree_transpose(
    outer_treedef=jax.tree_util.tree_structure([0 for e in episode_steps]),
    inner_treedef=jax.tree_util.tree_structure(episode_steps[0]),
    pytree_to_transpose=episode_steps
)
+++ {"id": "KlYA2R6N2h_8"}
More Information
For more information on pytrees in JAX and the operations that are available, see the Pytrees section in the JAX documentation.