{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "-h05_PNNhZ-D"
},
"source": [
"# Working with Pytrees\n",
"\n",
"[Open in Colab](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb)\n",
"\n",
"*Author: Vladimir Mikulik*\n",
"\n",
"Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you may sometimes see them called *nests*, or just *trees*.\n",
"\n",
"JAX has built-in support for such objects, both in its library functions and through the functions in [`jax.tree_util`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section explains how to use them, gives some useful snippets, and points out common gotchas."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9UjxVY9ulSCn"
},
"source": [
"## What is a pytree?\n",
"\n",
"As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html):\n",
"\n",
"> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.\n",
"\n",
"Some example pytrees:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Wh6BApZ9lrR1",
"outputId": "37b8d89c-8dd0-4f2b-f479-8333f4b3a2c3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 'a', <object object at 0x7fded60bb8c0>] has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>]\n",
"(1, (2, 3), ()) has 3 leaves: [1, 2, 3]\n",
"[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]\n",
"{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]\n",
"DeviceArray([1, 2, 3], dtype=int32) has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"example_trees = [\n",
"    [1, 'a', object()],\n",
"    (1, (2, 3), ()),\n",
"    [1, {'k1': 2, 'k2': (3, 4)}, 5],\n",
"    {'a': 2, 'b': (2, 3)},\n",
"    jnp.array([1, 2, 3]),\n",
"]\n",
"\n",
"# Let's see how many leaves they have:\n",
"for pytree in example_trees:\n",
"  leaves = jax.tree_leaves(pytree)\n",
"  print(f\"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_tWkkGNwW8vf"
},
"source": [
"Here we also used our first `jax.tree_*` function, `jax.tree_leaves`, which extracts the flattened leaves from a tree."
]
},
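{
"cell_type": "markdown",
"metadata": {},
"source": [
"`jax.tree_leaves` keeps only the leaves and drops the structure. A minimal sketch of the round trip underneath it uses `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten`. The first returns the leaves together with a *treedef* describing the structure. The second rebuilds a pytree of that structure from a new list of leaves. The example tree is illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"tree = [1, {'k1': 2, 'k2': (3, 4)}, 5]\n",
"\n",
"# Split the pytree into its leaves and a treedef describing its structure.\n",
"leaves, treedef = jax.tree_util.tree_flatten(tree)\n",
"print(leaves)  # [1, 2, 3, 4, 5]\n",
"\n",
"# Put new leaves back into the same structure.\n",
"doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])\n",
"print(doubled)  # [2, {'k1': 4, 'k2': (6, 8)}, 10]"
]
},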
{
"cell_type": "markdown",
"metadata": {
"id": "RcsmneIGlltm"
},
"source": [
"## Why pytrees?\n",
"\n",
"In machine learning, some places where you commonly find pytrees are:\n",
"* Model parameters\n",
"* Dataset entries\n",
"* RL agent observations\n",
"\n",
"They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sMrSGSIJn9MD"
},
"source": [
"## Common pytree functions\n",
"Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees: it applies a function to every leaf and returns a pytree with the same structure:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "wZRcuQu4n7o5",
"outputId": "3528bc9f-54ed-49c8-b79a-1cbea176c0f3"
},
"outputs": [
{
"data": {
"text/plain": [
"[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"list_of_lists = [\n",
"    [1, 2, 3],\n",
"    [1, 2],\n",
"    [1, 2, 3, 4]\n",
"]\n",
"\n",
"jax.tree_map(lambda x: x*2, list_of_lists)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xu8X3fk4orC9"
},
"source": [
"`jax.tree_map` also works with multiple pytree arguments, applying the function to the corresponding leaves of each:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "KVpB4r1OkeUK",
"outputId": "33f88a7e-aac7-48cd-d207-2c531cd37733"
},
"outputs": [
{
"data": {
"text/plain": [
"[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
]
},
"execution_count": 3,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"another_list_of_lists = list_of_lists\n",
"jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dkRKy3LvowAb"
},
"source": [
"When using multiple arguments with `jax.tree_map`, the structures of the inputs must match exactly: lists must have the same number of elements, dicts must have the same keys, and so on."
]
},
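{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of what a mismatch looks like in practice, using `jax.tree_util.tree_map` (the function that `jax.tree_map` aliases). `try_add` is a hypothetical helper. The exact exception type and message may differ between JAX versions, so the helper catches both `ValueError` and `TypeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"def try_add(x_tree, y_tree):\n",
"  \"\"\"Hypothetical helper: add two pytrees leaf-wise, reporting structure mismatches.\"\"\"\n",
"  try:\n",
"    return jax.tree_util.tree_map(lambda x, y: x + y, x_tree, y_tree)\n",
"  except (ValueError, TypeError) as e:\n",
"    print('Structure mismatch:', e)\n",
"    return None\n",
"\n",
"print(try_add([1, 2, 3], [10, 20, 30]))  # [11, 22, 33]\n",
"try_add([1, 2, 3], [10, 20])  # list lengths differ\n",
"try_add({'w': 1, 'b': 2}, {'w': 1, 'bias': 2})  # dict keys differ"
]
},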
{
"cell_type": "markdown",
"metadata": {
"id": "Lla4hDW6sgMZ"
},
"source": [
"## Example: ML model parameters\n",
"\n",
"A simple example of training an MLP shows some ways in which pytree operations come in handy:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "j2ZUzWx8tKB2"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def init_mlp_params(layer_widths):\n",
"  params = []\n",
"  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):\n",
"    params.append(\n",
"        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),\n",
"             biases=np.ones(shape=(n_out,))\n",
"            )\n",
"    )\n",
"  return params\n",
"\n",
"params = init_mlp_params([1, 128, 128, 1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUFwJOspuGvU"
},
"source": [
"We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "ErWsXuxXse-z",
"outputId": "d3e549ab-40ef-470e-e460-1b5939d9696f"
},
"outputs": [
{
"data": {
"text/plain": [
"[{'biases': (128,), 'weights': (1, 128)},\n",
" {'biases': (128,), 'weights': (128, 128)},\n",
" {'biases': (1,), 'weights': (128, 1)}]"
]
},
"execution_count": 5,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_map(lambda x: x.shape, params)"
]
},
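{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same leaves can be used for quick bookkeeping. For instance, a hypothetical `count_params` helper (not part of the original example) can sum the sizes of all leaves returned by `jax.tree_util.tree_leaves` to get the total number of parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"def count_params(params):\n",
"  \"\"\"Hypothetical helper: total number of scalar entries across all leaves.\"\"\"\n",
"  return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))\n",
"\n",
"# For layer widths [1, 128, 128, 1]:\n",
"# (128 + 16384 + 128) weights + (128 + 128 + 1) biases = 16897.\n",
"count_params(params)"
]
},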
{
"cell_type": "markdown",
"metadata": {
"id": "zQtRKaj4ua6-"
},
"source": [
"Now, let's train our MLP:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "iL4GvW9OuZ-X"
},
"outputs": [],
"source": [
"def forward(params, x):\n",
"  *hidden, last = params\n",
"  for layer in hidden:\n",
"    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])\n",
"  return x @ last['weights'] + last['biases']\n",
"\n",
"def loss_fn(params, x, y):\n",
"  return jnp.mean((forward(params, x) - y) ** 2)\n",
"\n",
"LEARNING_RATE = 0.0001\n",
"\n",
"@jax.jit\n",
"def update(params, x, y):\n",
"\n",
"  grads = jax.grad(loss_fn)(params, x, y)\n",
"  # Note that `grads` is a pytree with the same structure as `params`.\n",
"  # `jax.grad` is one of the many JAX functions that has\n",
"  # built-in support for pytrees.\n",
"\n",
"  # This is handy, because we can apply the SGD update using tree utils:\n",
"  return jax.tree_map(\n",
"      lambda p, g: p - LEARNING_RATE * g, params, grads\n",
"  )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "B3HniT9-xohz",
"outputId": "d77e9811-373e-45d6-ccbe-edb6f43120d7"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"xs = np.random.normal(size=(128, 1))\n",
"ys = xs ** 2\n",
"\n",
"for _ in range(1000):\n",
"  params = update(params, xs, ys)\n",
"\n",
"plt.scatter(xs, ys)\n",
"plt.scatter(xs, forward(params, xs), label='Model prediction')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sBxOB21YNEDA"
},
"source": [
"## Custom pytree nodes\n",
"\n",
"So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "CK8LN2PRFnQf"
},
"outputs": [],
"source": [
"class MyContainer:\n",
"  \"\"\"A named container.\"\"\"\n",
"\n",
"  def __init__(self, name: str, a: int, b: int, c: int):\n",
"    self.name = name\n",
"    self.a = a\n",
"    self.b = b\n",
"    self.c = c"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "OPGe2R7ZOXCT",
"outputId": "40db1f41-9df8-4dea-972a-6a7bc44a49c6"
},
"outputs": [
{
"data": {
"text/plain": [
"[<__main__.MyContainer at 0x7fdec166ce50>,\n",
" <__main__.MyContainer at 0x7fded89ba490>]"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_leaves([\n",
"    MyContainer('Alice', 1, 2, 3),\n",
"    MyContainer('Bob', 4, 5, 6)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vk4vucGXPADj"
},
"source": [
"Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "vIr9_JOIOku7",
"outputId": "dadc9c15-4a10-4fac-e70d-f23e7085cf74",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m jax.tree_map(lambda x: x + 1, [\n\u001b[1;32m 2\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m ])\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/tree_util.py\u001b[0m in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf)\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[1;32m 185\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,\n",
"\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m jax.tree_map(lambda x: x + 1, [\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m ])\n",
"\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for +: 'MyContainer' and 'int'"
]
}
],
"source": [
"jax.tree_map(lambda x: x + 1, [\n",
"    MyContainer('Alice', 1, 2, 3),\n",
"    MyContainer('Bob', 4, 5, 6)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nAZ4FR2lPN51"
},
"source": [
"To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "D_juQx-2OybX",
"outputId": "ee2cf4ad-ec21-4636-c9c5-2c64b81429bb"
},
"outputs": [
{
"data": {
"text/plain": [
"[1, 2, 3, 4, 5, 6]"
]
},
"execution_count": 11,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from typing import Tuple, Iterable\n",
"\n",
"def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:\n",
"  \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
"  flat_contents = [container.a, container.b, container.c]\n",
"\n",
"  # we don't want the name to appear as a child, so it is auxiliary data.\n",
"  # auxiliary data is usually a description of the structure of a node,\n",
"  # e.g., the keys of a dict -- anything that isn't a node's children.\n",
"  aux_data = container.name\n",
"  return flat_contents, aux_data\n",
"\n",
"def unflatten_MyContainer(\n",
"    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:\n",
"  \"\"\"Converts aux data and the flat contents into a MyContainer.\"\"\"\n",
"  return MyContainer(aux_data, *flat_contents)\n",
"\n",
"jax.tree_util.register_pytree_node(\n",
"    MyContainer, flatten_MyContainer, unflatten_MyContainer)\n",
"\n",
"jax.tree_leaves([\n",
"    MyContainer('Alice', 1, 2, 3),\n",
"    MyContainer('Bob', 4, 5, 6)\n",
"])"
]
},
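{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the container registered, the `tree_map` call that failed earlier should now operate on the integers inside each `MyContainer`, with `name` carried through as auxiliary data. `jax.tree_util` also provides a `register_pytree_node_class` decorator, which expects the class to define a `tree_flatten` method and a `tree_unflatten` classmethod. The sketch below shows both. `Point` is a hypothetical example class, and the cell assumes `MyContainer` has been registered as above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# With MyContainer registered, tree_map reaches the integers inside it,\n",
"# while `name` is carried along unchanged as auxiliary data.\n",
"result = jax.tree_util.tree_map(lambda x: x + 1, [\n",
"    MyContainer('Alice', 1, 2, 3),\n",
"    MyContainer('Bob', 4, 5, 6)\n",
"])\n",
"print([(c.name, c.a, c.b, c.c) for c in result])\n",
"# [('Alice', 2, 3, 4), ('Bob', 5, 6, 7)]\n",
"\n",
"\n",
"# Alternative registration via a decorator; `Point` is a hypothetical example class.\n",
"@jax.tree_util.register_pytree_node_class\n",
"class Point:\n",
"  \"\"\"A labelled 2D point: `label` is auxiliary data, `x` and `y` are children.\"\"\"\n",
"\n",
"  def __init__(self, label, x, y):\n",
"    self.label = label\n",
"    self.x = x\n",
"    self.y = y\n",
"\n",
"  def tree_flatten(self):\n",
"    # Return (children, aux_data), like flatten_MyContainer above.\n",
"    return (self.x, self.y), self.label\n",
"\n",
"  @classmethod\n",
"  def tree_unflatten(cls, aux_data, children):\n",
"    return cls(aux_data, *children)\n",
"\n",
"\n",
"p = jax.tree_util.tree_map(lambda v: v * 10, Point('origin', 1, 2))\n",
"print(p.label, p.x, p.y)  # origin 10 20"
]
},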
{
"cell_type": "markdown",
"metadata": {
"id": "JgnAp7fFShEB"
},
"source": [
"Modern Python comes equipped with helpful tools to make defining containers easier. Some of these work with JAX out of the box, but others require more care. For instance:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "8DNoLABtO0fr",
"outputId": "9a448508-43eb-4450-bfaf-eeeb59a9e349"
},
"outputs": [
{
"data": {
"text/plain": [
"['Alice', 1, 2, 3, 'Bob', 4, 5, 6]"
]
},
"execution_count": 12,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from typing import NamedTuple, Any\n",
"\n",
"class MyOtherContainer(NamedTuple):\n",
"  name: str\n",
"  a: Any\n",
"  b: Any\n",
"  c: Any\n",
"\n",
"# Since `tuple` is already registered with JAX, and NamedTuple is a subclass,\n",
"# this will work out-of-the-box:\n",
"jax.tree_leaves([\n",
"    MyOtherContainer('Alice', 1, 2, 3),\n",
"    MyOtherContainer('Bob', 4, 5, 6)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TVdtzJDVTZb6"
},
"source": [
"Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way."
]
},
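{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that price in practice, reusing `MyOtherContainer` from above: any function mapped over the container is also applied to `name`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# `name` is a leaf here, so the mapped function is applied to it as well.\n",
"doubled = jax.tree_util.tree_map(lambda x: x * 2, MyOtherContainer('Alice', 1, 2, 3))\n",
"print(doubled)  # MyOtherContainer(name='AliceAlice', a=2, b=4, c=6)\n",
"\n",
"# A function that only makes sense for numbers fails on the string leaf.\n",
"try:\n",
"  jax.tree_util.tree_map(lambda x: x + 1, MyOtherContainer('Alice', 1, 2, 3))\n",
"except TypeError as e:\n",
"  print('TypeError:', e)"
]
},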
{
"cell_type": "markdown",
"metadata": {
"id": "kNsTszcEEHD0"
},
"source": [
"## Common pytree gotchas and patterns"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ki-JDENzyL7"
},
"source": [
"### Gotchas\n",
"#### Mistaking nodes for leaves\n",
"A common problem to look out for is accidentally introducing tree nodes instead of leaves:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "N-th4jOAGJlM",
"outputId": "23eed14d-d383-4d88-d6f9-02bac06020df"
},
"outputs": [
{
"data": {
"text/plain": [
"[(DeviceArray([1., 1.], dtype=float32),\n",
" DeviceArray([1., 1., 1.], dtype=float32)),\n",
" (DeviceArray([1., 1., 1.], dtype=float32),\n",
" DeviceArray([1., 1., 1., 1.], dtype=float32))]"
]
},
"execution_count": 13,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
"\n",
"# Try to make another tree with ones instead of zeros\n",
"shapes = jax.tree_map(lambda x: x.shape, a_tree)\n",
"jax.tree_map(jnp.ones, shapes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q8d4y-hfHTWh"
},
"source": [
"What happened is that the `shape` of an array is a tuple, which is a pytree node with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.\n",
"\n",
"The solution will depend on the specifics, but there are two broadly applicable options:\n",
"* Rewrite the code to avoid the intermediate `tree_map`.\n",
"* Convert the tuple into an `np.array` or `jnp.array`, which makes the entire sequence a leaf."
]
},
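{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the two options applied to the example above, using `jax.tree_util.tree_map`. In the second option, the shapes are stored as `np.array` leaves and converted back to tuples of Python ints before calling `jnp.ones`. That conversion is a defensive choice made here, not something the text requires."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
"\n",
"# Option 1: skip the intermediate tree of shapes and map over the arrays directly.\n",
"ones_tree = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), a_tree)\n",
"\n",
"# Option 2: store each shape as an array, so the whole shape is a single leaf.\n",
"shapes = jax.tree_util.tree_map(lambda x: np.array(x.shape), a_tree)\n",
"ones_tree_2 = jax.tree_util.tree_map(\n",
"    lambda s: jnp.ones(tuple(int(d) for d in s)), shapes)\n",
"\n",
"print([x.shape for x in ones_tree])    # [(2, 3), (3, 4)]\n",
"print([x.shape for x in ones_tree_2])  # [(2, 3), (3, 4)]"
]
},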
{
"cell_type": "markdown",
"metadata": {
"id": "4OKlbFlEIda-"
},
"source": [
"#### Handling of None\n",
"`jax.tree_util` treats `None` as a node without children, not as a leaf:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "gIwlwo2MJcEC",
"outputId": "1e59f323-a7b7-42be-8603-afa4693c00cc"
},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 14,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_leaves([None, None, None])"
]
},
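{
"cell_type": "markdown",
"metadata": {},
"source": [
"One practical consequence is that `tree_map` never calls the mapped function on `None`. The `None` entries are simply carried through to the output as part of the structure. A small sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# The function is only applied to the leaves 1 and 2; None stays in place.\n",
"print(jax.tree_util.tree_map(lambda x: x * 10, [1, None, 2]))  # [10, None, 20]\n",
"\n",
"# None is part of the structure (the treedef), not of the leaves.\n",
"leaves, treedef = jax.tree_util.tree_flatten([1, None, 2])\n",
"print(leaves)  # [1, 2]"
]
},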
{
"cell_type": "markdown",
"metadata": {
"id": "pwNz-rp1JvW4"
},
"source": [
"### Patterns\n",
"#### Transposing trees\n",
"\n",
"If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_map`:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "UExN7-G7qU-F",
"outputId": "fd049086-ef37-44db-8e2c-9f1bd9fad950"
},
"outputs": [
{
"data": {
"text/plain": [
"{'obs': [3, 4], 't': [1, 2]}"
]
},
"execution_count": 15,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def tree_transpose(list_of_trees):\n",
"  \"\"\"Convert a list of trees of identical structure into a single tree of lists.\"\"\"\n",
"  return jax.tree_map(lambda *xs: list(xs), *list_of_trees)\n",
"\n",
"\n",
"# Convert a dataset from row-major to column-major:\n",
"episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]\n",
"tree_transpose(episode_steps)"
]
},
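{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the leaves are arrays, a common variant of the same pattern stacks them along a new leading axis instead of collecting them into Python lists. A minimal sketch, where `tree_stack` is a hypothetical helper and the step values are illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def tree_stack(list_of_trees):\n",
"  \"\"\"Hypothetical helper: stack matching leaves of identically structured trees.\"\"\"\n",
"  return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *list_of_trees)\n",
"\n",
"steps = [dict(t=1, obs=jnp.array([0., 1.])), dict(t=2, obs=jnp.array([2., 3.]))]\n",
"stacked = tree_stack(steps)\n",
"print(stacked['t'].shape, stacked['obs'].shape)  # (2,) (2, 2)"
]
},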
{
"cell_type": "markdown",
"metadata": {
"id": "Ao6R2ffm2CF4"
},
"source": [
"For more complicated transposes, JAX provides `jax.tree_transpose`, which is more verbose but allows you to specify the structure of the inner and outer pytree for more flexibility:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "bZvVwxshz1D3",
"outputId": "a0314dc8-4267-41e6-a763-931d40433c26"
},
"outputs": [
{
"data": {
"text/plain": [
"{'obs': [3, 4], 't': [1, 2]}"
]
},
"execution_count": 16,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_transpose(\n",
"    outer_treedef=jax.tree_structure([0 for e in episode_steps]),\n",
"    inner_treedef=jax.tree_structure(episode_steps[0]),\n",
"    pytree_to_transpose=episode_steps\n",
")"
]
},
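{
"cell_type": "markdown",
"metadata": {},
"source": [
"Swapping the two treedefs goes back the other way, turning the tree of lists into a list of trees again. A minimal sketch, reusing `episode_steps` from above and calling the `jax.tree_util` versions of the same functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"list_treedef = jax.tree_util.tree_structure([0 for _ in episode_steps])\n",
"step_treedef = jax.tree_util.tree_structure(episode_steps[0])\n",
"\n",
"# List of dicts -> dict of lists, as in the cell above.\n",
"columns = jax.tree_util.tree_transpose(\n",
"    outer_treedef=list_treedef,\n",
"    inner_treedef=step_treedef,\n",
"    pytree_to_transpose=episode_steps)\n",
"print(columns == {'obs': [3, 4], 't': [1, 2]})  # True\n",
"\n",
"# Dict of lists -> list of dicts: the dict is now the outer structure.\n",
"rows = jax.tree_util.tree_transpose(\n",
"    outer_treedef=step_treedef,\n",
"    inner_treedef=list_treedef,\n",
"    pytree_to_transpose=columns)\n",
"print(rows == episode_steps)  # True"
]
},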
{
"cell_type": "markdown",
"metadata": {
"id": "KlYA2R6N2h_8"
},
"source": [
"## More Information\n",
"\n",
"For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "jax101-pytrees",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}