mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 07:56:06 +00:00
755 lines
34 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "-h05_PNNhZ-D"
|
||
},
|
||
"source": [
|
||
"# Working with Pytrees\n",
|
||
"\n",
"[Open in Colab](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb)\n",
|
||
"\n",
|
||
"*Author: Vladimir Mikulik*\n",
|
||
"\n",
|
||
"Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you can sometimes see them called *nests*, or just *trees*.\n",
|
||
"\n",
"JAX has built-in support for such objects, both in its library functions and through the functions in [`jax.tree_util`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section explains how to use them, gives some useful snippets, and points out common gotchas."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "9UjxVY9ulSCn"
|
||
},
|
||
"source": [
|
||
"## What is a pytree?\n",
|
||
"\n",
|
||
"As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html):\n",
|
||
"\n",
|
||
"> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.\n",
|
||
"\n",
|
||
"Some example pytrees:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"id": "Wh6BApZ9lrR1",
|
||
"outputId": "37b8d89c-8dd0-4f2b-f479-8333f4b3a2c3"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[1, 'a', <object object at 0x7fded60bb8c0>] has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>]\n",
|
||
"(1, (2, 3), ()) has 3 leaves: [1, 2, 3]\n",
|
||
"[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]\n",
|
||
"{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]\n",
|
||
"DeviceArray([1, 2, 3], dtype=int32) has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import jax\n",
|
||
"import jax.numpy as jnp\n",
|
||
"\n",
|
||
"example_trees = [\n",
|
||
" [1, 'a', object()],\n",
|
||
" (1, (2, 3), ()),\n",
|
||
" [1, {'k1': 2, 'k2': (3, 4)}, 5],\n",
|
||
" {'a': 2, 'b': (2, 3)},\n",
|
||
" jnp.array([1, 2, 3]),\n",
|
||
"]\n",
|
||
"\n",
|
||
"# Let's see how many leaves they have:\n",
"for pytree in example_trees:\n",
"  leaves = jax.tree_leaves(pytree)\n",
"  print(f\"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "_tWkkGNwW8vf"
|
||
},
|
||
"source": [
"We've also introduced our first `jax.tree_*` function, `jax.tree_leaves`, which extracts the flattened leaves from a tree."
|
||
]
|
||
},
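{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a complementary sketch (not part of the original tutorial), `jax.tree_util.tree_flatten` returns the leaves together with a `treedef` describing the container structure, and `jax.tree_util.tree_unflatten` rebuilds a tree from the two. The example tree below is made up for illustration, and the calls are spelled via `jax.tree_util` on the assumption that this spelling is available in your JAX version:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# Hypothetical example tree, chosen only for illustration.\n",
"tree = {'a': [1, 2], 'b': (3, {'c': 4})}\n",
"\n",
"# Split the tree into a flat list of leaves and a structure description.\n",
"leaves, treedef = jax.tree_util.tree_flatten(tree)\n",
"print(leaves)\n",
"\n",
"# Transform the flat leaves, then rebuild a tree with the same structure.\n",
"doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])\n",
"print(doubled)"
]
},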
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "RcsmneIGlltm"
|
||
},
|
||
"source": [
|
||
"## Why pytrees?\n",
|
||
"\n",
|
||
"In machine learning, some places where you commonly find pytrees are:\n",
|
||
"* Model parameters\n",
|
||
"* Dataset entries\n",
|
||
"* RL agent observations\n",
|
||
"\n",
|
||
"They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts)."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "sMrSGSIJn9MD"
|
||
},
|
||
"source": [
|
||
"## Common pytree functions\n",
|
||
"Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"id": "wZRcuQu4n7o5",
|
||
"outputId": "3528bc9f-54ed-49c8-b79a-1cbea176c0f3"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"list_of_lists = [\n",
|
||
" [1, 2, 3],\n",
|
||
" [1, 2],\n",
|
||
" [1, 2, 3, 4]\n",
|
||
"]\n",
|
||
"\n",
|
||
"jax.tree_map(lambda x: x*2, list_of_lists)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "xu8X3fk4orC9"
|
||
},
|
||
"source": [
|
||
"`jax.tree_map` also works with multiple arguments:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"id": "KVpB4r1OkeUK",
|
||
"outputId": "33f88a7e-aac7-48cd-d207-2c531cd37733"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"another_list_of_lists = list_of_lists\n",
|
||
"jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "dkRKy3LvowAb"
|
||
},
|
||
"source": [
|
||
"When using multiple arguments with `jax.tree_map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc."
|
||
]
|
||
},
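{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this requirement, using made-up trees: the first call succeeds because both inputs share a structure, while the second is expected to fail because one list has an extra element. Catching `ValueError` reflects the behavior of the JAX versions we are aware of and is an assumption rather than a documented guarantee:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# Hypothetical trees with identical structure.\n",
"tree_a = {'w': [1, 2], 'b': 3}\n",
"tree_b = {'w': [10, 20], 'b': 30}\n",
"\n",
"# Same structure: leaves are combined pairwise.\n",
"print(jax.tree_util.tree_map(lambda x, y: x + y, tree_a, tree_b))\n",
"\n",
"# Different structure: the list under 'w' has three elements instead of two.\n",
"tree_c = {'w': [10, 20, 30], 'b': 30}\n",
"try:\n",
"  jax.tree_util.tree_map(lambda x, y: x + y, tree_a, tree_c)\n",
"except ValueError as err:  # assumed exception type, see the note above\n",
"  print('Structure mismatch:', err)"
]
},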
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "Lla4hDW6sgMZ"
|
||
},
|
||
"source": [
|
||
"## Example: ML model parameters\n",
|
||
"\n",
"A simple example of training an MLP shows some of the ways in which pytree operations come in handy:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {
|
||
"id": "j2ZUzWx8tKB2"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"\n",
"def init_mlp_params(layer_widths):\n",
"  params = []\n",
"  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):\n",
"    params.append(\n",
"        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),\n",
"             biases=np.ones(shape=(n_out,))\n",
"        )\n",
"    )\n",
"  return params\n",
|
||
"\n",
|
||
"params = init_mlp_params([1, 128, 128, 1])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "kUFwJOspuGvU"
|
||
},
|
||
"source": [
|
||
"We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"id": "ErWsXuxXse-z",
|
||
"outputId": "d3e549ab-40ef-470e-e460-1b5939d9696f"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[{'biases': (128,), 'weights': (1, 128)},\n",
|
||
" {'biases': (128,), 'weights': (128, 128)},\n",
|
||
" {'biases': (1,), 'weights': (128, 1)}]"
|
||
]
|
||
},
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"jax.tree_map(lambda x: x.shape, params)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "zQtRKaj4ua6-"
|
||
},
|
||
"source": [
|
||
"Now, let's train our MLP:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {
|
||
"id": "iL4GvW9OuZ-X"
|
||
},
|
||
"outputs": [],
|
||
"source": [
"def forward(params, x):\n",
"  *hidden, last = params\n",
"  for layer in hidden:\n",
"    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])\n",
"  return x @ last['weights'] + last['biases']\n",
"\n",
"def loss_fn(params, x, y):\n",
"  return jnp.mean((forward(params, x) - y) ** 2)\n",
|
||
"\n",
|
||
"LEARNING_RATE = 0.0001\n",
|
||
"\n",
|
||
"@jax.jit\n",
"def update(params, x, y):\n",
"\n",
"  grads = jax.grad(loss_fn)(params, x, y)\n",
"  # Note that `grads` is a pytree with the same structure as `params`.\n",
"  # `jax.grad` is one of the many JAX functions that have\n",
"  # built-in support for pytrees.\n",
"\n",
"  # This is handy, because we can apply the SGD update using tree utils:\n",
"  return jax.tree_map(\n",
"      lambda p, g: p - LEARNING_RATE * g, params, grads\n",
"  )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {
|
||
"id": "B3HniT9-xohz",
|
||
"outputId": "d77e9811-373e-45d6-ccbe-edb6f43120d7"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light",
|
||
"tags": []
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
|
||
"\n",
|
||
"xs = np.random.normal(size=(128, 1))\n",
|
||
"ys = xs ** 2\n",
|
||
"\n",
"for _ in range(1000):\n",
"  params = update(params, xs, ys)\n",
|
||
"\n",
|
||
"plt.scatter(xs, ys)\n",
|
||
"plt.scatter(xs, forward(params, xs), label='Model prediction')\n",
|
||
"plt.legend();"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "sBxOB21YNEDA"
|
||
},
|
||
"source": [
|
||
"## Custom pytree nodes\n",
|
||
"\n",
|
||
"So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {
|
||
"id": "CK8LN2PRFnQf"
|
||
},
|
||
"outputs": [],
|
||
"source": [
"class MyContainer:\n",
"  \"\"\"A named container.\"\"\"\n",
"\n",
"  def __init__(self, name: str, a: int, b: int, c: int):\n",
"    self.name = name\n",
"    self.a = a\n",
"    self.b = b\n",
"    self.c = c"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {
|
||
"id": "OPGe2R7ZOXCT",
|
||
"outputId": "40db1f41-9df8-4dea-972a-6a7bc44a49c6"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[<__main__.MyContainer at 0x7fdec166ce50>,\n",
|
||
" <__main__.MyContainer at 0x7fded89ba490>]"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"jax.tree_leaves([\n",
|
||
" MyContainer('Alice', 1, 2, 3),\n",
|
||
" MyContainer('Bob', 4, 5, 6)\n",
|
||
"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "vk4vucGXPADj"
|
||
},
|
||
"source": [
|
||
"Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {
|
||
"id": "vIr9_JOIOku7",
|
||
"outputId": "dadc9c15-4a10-4fac-e70d-f23e7085cf74"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"ename": "TypeError",
|
||
"evalue": "ignored",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
||
"\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m jax.tree_map(lambda x: x + 1, [\n\u001b[1;32m 2\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m ])\n",
|
||
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/tree_util.py\u001b[0m in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf)\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[1;32m 185\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,\n",
|
||
"\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m jax.tree_map(lambda x: x + 1, [\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m ])\n",
|
||
"\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for +: 'MyContainer' and 'int'"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"jax.tree_map(lambda x: x + 1, [\n",
|
||
" MyContainer('Alice', 1, 2, 3),\n",
|
||
" MyContainer('Bob', 4, 5, 6)\n",
|
||
"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "nAZ4FR2lPN51",
|
||
"tags": [
|
||
"raises-exception"
|
||
]
|
||
},
|
||
"source": [
|
||
"To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"id": "D_juQx-2OybX",
|
||
"outputId": "ee2cf4ad-ec21-4636-c9c5-2c64b81429bb"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[1, 2, 3, 4, 5, 6]"
|
||
]
|
||
},
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import Tuple, Iterable\n",
|
||
"\n",
"def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:\n",
"  \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
"  flat_contents = [container.a, container.b, container.c]\n",
"\n",
"  # we don't want the name to appear as a child, so it is auxiliary data.\n",
"  # auxiliary data is usually a description of the structure of a node,\n",
"  # e.g., the keys of a dict -- anything that isn't a node's children.\n",
"  aux_data = container.name\n",
"  return flat_contents, aux_data\n",
"\n",
"def unflatten_MyContainer(\n",
"    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:\n",
"  \"\"\"Converts aux data and the flat contents into a MyContainer.\"\"\"\n",
"  return MyContainer(aux_data, *flat_contents)\n",
|
||
"\n",
|
||
"jax.tree_util.register_pytree_node(\n",
|
||
" MyContainer, flatten_MyContainer, unflatten_MyContainer)\n",
|
||
"\n",
|
||
"jax.tree_leaves([\n",
|
||
" MyContainer('Alice', 1, 2, 3),\n",
|
||
" MyContainer('Bob', 4, 5, 6)\n",
|
||
"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "JgnAp7fFShEB"
|
||
},
|
||
"source": [
|
||
"Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"id": "8DNoLABtO0fr",
|
||
"outputId": "9a448508-43eb-4450-bfaf-eeeb59a9e349"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"['Alice', 1, 2, 3, 'Bob', 4, 5, 6]"
|
||
]
|
||
},
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import NamedTuple, Any\n",
|
||
"\n",
"class MyOtherContainer(NamedTuple):\n",
"  name: str\n",
"  a: Any\n",
"  b: Any\n",
"  c: Any\n",
|
||
"\n",
|
||
"# Since `tuple` is already registered with JAX, and NamedTuple is a subclass,\n",
|
||
"# this will work out-of-the-box:\n",
|
||
"jax.tree_leaves([\n",
|
||
" MyOtherContainer('Alice', 1, 2, 3),\n",
|
||
" MyOtherContainer('Bob', 4, 5, 6)\n",
|
||
"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "TVdtzJDVTZb6"
|
||
},
|
||
"source": [
|
||
"Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "kNsTszcEEHD0"
|
||
},
|
||
"source": [
|
||
"## Common pytree gotchas and patterns"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "0ki-JDENzyL7"
|
||
},
|
||
"source": [
|
||
"### Gotchas\n",
|
||
"#### Mistaking nodes for leaves\n",
|
||
"A common problem to look out for is accidentally introducing tree nodes instead of leaves:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"id": "N-th4jOAGJlM",
|
||
"outputId": "23eed14d-d383-4d88-d6f9-02bac06020df"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[(DeviceArray([1., 1.], dtype=float32),\n",
|
||
" DeviceArray([1., 1., 1.], dtype=float32)),\n",
|
||
" (DeviceArray([1., 1., 1.], dtype=float32),\n",
|
||
" DeviceArray([1., 1., 1., 1.], dtype=float32))]"
|
||
]
|
||
},
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
|
||
"\n",
|
||
"# Try to make another tree with ones instead of zeros\n",
|
||
"shapes = jax.tree_map(lambda x: x.shape, a_tree)\n",
|
||
"jax.tree_map(jnp.ones, shapes)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "q8d4y-hfHTWh"
|
||
},
|
||
"source": [
|
||
"What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.\n",
|
||
"\n",
|
||
"The solution will depend on the specifics, but there are two broadly applicable options:\n",
|
||
"* rewrite the code to avoid the intermediate `tree_map`.\n",
|
||
"* convert the tuple into an `np.array` or `jnp.array`, which makes the entire\n",
|
||
"sequence a leaf."
|
||
]
|
||
},
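{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both options can be sketched as follows, reusing the `a_tree` example from above. Using `jnp.ones_like` for the first option and storing shapes as `np.array` leaves for the second are illustrative choices, not the only way to apply them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
"\n",
"# Option 1: avoid the intermediate tree of shapes entirely.\n",
"ones_tree = jax.tree_util.tree_map(jnp.ones_like, a_tree)\n",
"print([x.shape for x in ones_tree])\n",
"\n",
"# Option 2: store each shape as an array, so the whole shape is one leaf.\n",
"shapes = jax.tree_util.tree_map(lambda x: np.array(x.shape), a_tree)\n",
"ones_tree = jax.tree_util.tree_map(\n",
"    lambda s: jnp.ones(tuple(int(d) for d in s)), shapes)\n",
"print([x.shape for x in ones_tree])"
]
},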
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "4OKlbFlEIda-"
|
||
},
|
||
"source": [
|
||
"#### Handling of None\n",
"`jax.tree_util` treats `None` as a node without children, not as a leaf:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {
|
||
"id": "gIwlwo2MJcEC",
|
||
"outputId": "1e59f323-a7b7-42be-8603-afa4693c00cc"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[]"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"jax.tree_leaves([None, None, None])"
|
||
]
|
||
},
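{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `None` entries should be kept as leaves, one possible approach is the `is_leaf` argument accepted by the tree utilities. This is a sketch assuming a JAX version in which `jax.tree_util.tree_leaves` takes `is_leaf`; the input list is made up for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# Hypothetical input mixing values and None.\n",
"tree = [1, None, 2]\n",
"\n",
"# By default, None is an empty node and contributes no leaves.\n",
"print(jax.tree_util.tree_leaves(tree))\n",
"\n",
"# With is_leaf, None is kept as a leaf in its own right.\n",
"print(jax.tree_util.tree_leaves(tree, is_leaf=lambda x: x is None))"
]
},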
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "pwNz-rp1JvW4"
|
||
},
|
||
"source": [
|
||
"### Patterns\n",
|
||
"#### Transposing trees\n",
|
||
"\n",
|
||
"If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_map`:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {
|
||
"id": "UExN7-G7qU-F",
|
||
"outputId": "fd049086-ef37-44db-8e2c-9f1bd9fad950"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'obs': [3, 4], 't': [1, 2]}"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
"def tree_transpose(list_of_trees):\n",
"  \"\"\"Convert a list of trees of identical structure into a single tree of lists.\"\"\"\n",
"  return jax.tree_map(lambda *xs: list(xs), *list_of_trees)\n",
|
||
"\n",
|
||
"\n",
|
||
"# Convert a dataset from row-major to column-major:\n",
|
||
"episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]\n",
|
||
"tree_transpose(episode_steps)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "Ao6R2ffm2CF4"
|
||
},
|
||
"source": [
"For more complicated transposes, JAX provides `jax.tree_transpose`, which is more verbose but lets you specify the structure of the inner and outer pytree for more flexibility:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"id": "bZvVwxshz1D3",
|
||
"outputId": "a0314dc8-4267-41e6-a763-931d40433c26"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'obs': [3, 4], 't': [1, 2]}"
|
||
]
|
||
},
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"jax.tree_transpose(\n",
|
||
" outer_treedef = jax.tree_structure([0 for e in episode_steps]),\n",
|
||
" inner_treedef = jax.tree_structure(episode_steps[0]),\n",
|
||
" pytree_to_transpose = episode_steps\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "KlYA2R6N2h_8"
|
||
},
|
||
"source": [
|
||
"## More Information\n",
|
||
"\n",
|
||
"For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation."
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"colab": {
|
||
"collapsed_sections": [],
|
||
"name": "jax101-pytrees",
|
||
"provenance": []
|
||
},
|
||
"jupytext": {
|
||
"formats": "ipynb,md:myst"
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"name": "python3"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
}
|