rocm_jax/docs/jax-101/05.1-pytrees.ipynb
2022-04-01 14:52:16 -07:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "-h05_PNNhZ-D"
},
"source": [
"# Working with Pytrees\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb)\n",
"\n",
"*Author: Vladimir Mikulik*\n",
"\n",
"Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you can sometimes see them called *nests*, or just *trees*.\n",
"\n",
"JAX has built-in support for such objects, both in its library functions and through the functions in [`jax.tree_util`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section explains how to use them, gives some useful snippets, and points out common gotchas."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9UjxVY9ulSCn"
},
"source": [
"## What is a pytree?\n",
"\n",
"As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html):\n",
"\n",
"> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that's not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.\n",
"\n",
"Some example pytrees:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Wh6BApZ9lrR1",
"outputId": "37b8d89c-8dd0-4f2b-f479-8333f4b3a2c3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 'a', <object object at 0x7fded60bb8c0>] has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>]\n",
"(1, (2, 3), ()) has 3 leaves: [1, 2, 3]\n",
"[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]\n",
"{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]\n",
"DeviceArray([1, 2, 3], dtype=int32) has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"example_trees = [\n",
" [1, 'a', object()],\n",
" (1, (2, 3), ()),\n",
" [1, {'k1': 2, 'k2': (3, 4)}, 5],\n",
" {'a': 2, 'b': (2, 3)},\n",
" jnp.array([1, 2, 3]),\n",
"]\n",
"\n",
"# Let's see how many leaves they have:\n",
"for pytree in example_trees:\n",
"  leaves = jax.tree_leaves(pytree)\n",
"  print(f\"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_tWkkGNwW8vf"
},
"source": [
"We've also introduced our first `jax.tree_*` function, which allowed us to extract the flattened leaves from the trees."
]
},
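{
"cell_type": "markdown",
"metadata": {},
"source": [
"`jax.tree_leaves` throws the structure of the tree away. Below is a minimal sketch of the related pair `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten`. These also produce and consume a *treedef* that describes the structure, so modified leaves can be put back in place. The example tree reuses one of the trees above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"example_tree = [1, {'k1': 2, 'k2': (3, 4)}, 5]\n",
"\n",
"# Flatten into a list of leaves plus a treedef describing the structure.\n",
"leaves, treedef = jax.tree_util.tree_flatten(example_tree)\n",
"print(leaves)              # [1, 2, 3, 4, 5]\n",
"print(treedef.num_leaves)  # 5\n",
"\n",
"# Rebuild a tree with the same structure from a new list of leaves.\n",
"doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])\n",
"print(doubled)             # [2, {'k1': 4, 'k2': (6, 8)}, 10]"
]
},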
{
"cell_type": "markdown",
"metadata": {
"id": "RcsmneIGlltm"
},
"source": [
"## Why pytrees?\n",
"\n",
"In machine learning, some places where you commonly find pytrees are:\n",
"* Model parameters\n",
"* Dataset entries\n",
"* RL agent observations\n",
"\n",
"They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sMrSGSIJn9MD"
},
"source": [
"## Common pytree functions\n",
"Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "wZRcuQu4n7o5",
"outputId": "3528bc9f-54ed-49c8-b79a-1cbea176c0f3"
},
"outputs": [
{
"data": {
"text/plain": [
"[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"list_of_lists = [\n",
" [1, 2, 3],\n",
" [1, 2],\n",
" [1, 2, 3, 4]\n",
"]\n",
"\n",
"jax.tree_map(lambda x: x*2, list_of_lists)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xu8X3fk4orC9"
},
"source": [
"`jax.tree_map` also works with multiple arguments:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "KVpB4r1OkeUK",
"outputId": "33f88a7e-aac7-48cd-d207-2c531cd37733"
},
"outputs": [
{
"data": {
"text/plain": [
"[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
]
},
"execution_count": 3,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"another_list_of_lists = list_of_lists\n",
"jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dkRKy3LvowAb"
},
"source": [
"When using multiple arguments with `jax.tree_map`, the structure of the inputs must match exactly. That is, lists must have the same number of elements, dicts must have the same keys, and so on."
]
},
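{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a small sketch of both sides of this rule, using made-up lists. A length mismatch makes `tree_map` raise an error, which is a `ValueError` in the JAX versions we are aware of. Strictly speaking, the `jax.tree_util.tree_map` documentation only requires the extra trees to have the first tree as a *prefix*. A leaf of the first tree may line up with a whole subtree in the others, and that subtree is then passed to the function unchanged:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# Mismatched list lengths: the leaves cannot be paired up.\n",
"try:\n",
"  jax.tree_util.tree_map(lambda x, y: x + y, [1, 2, 3], [1, 2])\n",
"  mismatch_error = None\n",
"except ValueError as e:\n",
"  mismatch_error = e\n",
"print(type(mismatch_error).__name__)  # ValueError\n",
"\n",
"# Prefix matching: the leaf 1 is paired with the whole subtree (10, 20).\n",
"pairs = jax.tree_util.tree_map(lambda x, y: (x, y), [1, 2], [(10, 20), 30])\n",
"print(pairs)  # [(1, (10, 20)), (2, 30)]"
]
},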
{
"cell_type": "markdown",
"metadata": {
"id": "Lla4hDW6sgMZ"
},
"source": [
"## Example: ML model parameters\n",
"\n",
"A simple example of training an MLP shows some of the ways in which pytree operations are useful:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "j2ZUzWx8tKB2"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def init_mlp_params(layer_widths):\n",
"  params = []\n",
"  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):\n",
"    params.append(\n",
"        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),\n",
"             biases=np.ones(shape=(n_out,))\n",
"            )\n",
"    )\n",
"  return params\n",
"\n",
"params = init_mlp_params([1, 128, 128, 1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUFwJOspuGvU"
},
"source": [
"We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "ErWsXuxXse-z",
"outputId": "d3e549ab-40ef-470e-e460-1b5939d9696f"
},
"outputs": [
{
"data": {
"text/plain": [
"[{'biases': (128,), 'weights': (1, 128)},\n",
" {'biases': (128,), 'weights': (128, 128)},\n",
" {'biases': (1,), 'weights': (128, 1)}]"
]
},
"execution_count": 5,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_map(lambda x: x.shape, params)"
]
},
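{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other tree utilities help in the same way. For instance, this sketch counts the total number of parameters by summing the sizes of all leaves. So that it runs on its own, it rebuilds the same `[1, 128, 128, 1]` layer widths with zero-filled arrays (`toy_params` is a name made up for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import numpy as np\n",
"\n",
"# Same layer widths as above; zeros suffice because only the sizes matter.\n",
"widths = [1, 128, 128, 1]\n",
"toy_params = [\n",
"    dict(weights=np.zeros((n_in, n_out)), biases=np.zeros((n_out,)))\n",
"    for n_in, n_out in zip(widths[:-1], widths[1:])\n",
"]\n",
"\n",
"# Every leaf is an array, so summing `.size` over the leaves counts parameters.\n",
"num_params = sum(x.size for x in jax.tree_util.tree_leaves(toy_params))\n",
"print(num_params)  # 16897"
]
},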
{
"cell_type": "markdown",
"metadata": {
"id": "zQtRKaj4ua6-"
},
"source": [
"Now, let's train our MLP:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "iL4GvW9OuZ-X"
},
"outputs": [],
"source": [
"def forward(params, x):\n",
"  *hidden, last = params\n",
"  for layer in hidden:\n",
"    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])\n",
"  return x @ last['weights'] + last['biases']\n",
"\n",
"def loss_fn(params, x, y):\n",
"  return jnp.mean((forward(params, x) - y) ** 2)\n",
"\n",
"LEARNING_RATE = 0.0001\n",
"\n",
"@jax.jit\n",
"def update(params, x, y):\n",
"\n",
"  grads = jax.grad(loss_fn)(params, x, y)\n",
"  # Note that `grads` is a pytree with the same structure as `params`.\n",
"  # `jax.grad` is one of the many JAX functions that has\n",
"  # built-in support for pytrees.\n",
"\n",
"  # This is handy, because we can apply the SGD update using tree utils:\n",
"  return jax.tree_map(\n",
"      lambda p, g: p - LEARNING_RATE * g, params, grads\n",
"  )"
]
},
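{
"cell_type": "markdown",
"metadata": {},
"source": [
"The claim in the comment above, that `grads` has the same structure as `params`, can be checked directly with `jax.tree_util.tree_structure`. Here is a self-contained sketch that uses a tiny, made-up one-layer parameter list (`toy_params`, `toy_loss`) and a squared-error loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# A tiny, made-up parameter pytree: a single linear layer.\n",
"toy_params = [dict(weights=jnp.ones((2, 1)), biases=jnp.zeros((1,)))]\n",
"\n",
"def toy_loss(params, x, y):\n",
"  layer = params[0]\n",
"  pred = x @ layer['weights'] + layer['biases']\n",
"  return jnp.mean((pred - y) ** 2)\n",
"\n",
"x = jnp.ones((4, 2))\n",
"y = jnp.zeros((4, 1))\n",
"grads = jax.grad(toy_loss)(toy_params, x, y)\n",
"\n",
"# Same treedef as the parameters: a list holding one dict with the same keys.\n",
"print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(toy_params))  # True\n",
"print(jax.tree_util.tree_map(lambda g: g.shape, grads))  # [{'biases': (1,), 'weights': (2, 1)}]"
]
},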
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "B3HniT9-xohz",
"outputId": "d77e9811-373e-45d6-ccbe-edb6f43120d7"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"xs = np.random.normal(size=(128, 1))\n",
"ys = xs ** 2\n",
"\n",
"for _ in range(1000):\n",
"  params = update(params, xs, ys)\n",
"\n",
"plt.scatter(xs, ys, label='Data')\n",
"plt.scatter(xs, forward(params, xs), label='Model prediction')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sBxOB21YNEDA"
},
"source": [
"## Custom pytree nodes\n",
"\n",
"So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "CK8LN2PRFnQf"
},
"outputs": [],
"source": [
"class MyContainer:\n",
"  \"\"\"A named container.\"\"\"\n",
"\n",
"  def __init__(self, name: str, a: int, b: int, c: int):\n",
"    self.name = name\n",
"    self.a = a\n",
"    self.b = b\n",
"    self.c = c"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "OPGe2R7ZOXCT",
"outputId": "40db1f41-9df8-4dea-972a-6a7bc44a49c6"
},
"outputs": [
{
"data": {
"text/plain": [
"[<__main__.MyContainer at 0x7fdec166ce50>,\n",
" <__main__.MyContainer at 0x7fded89ba490>]"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_leaves([\n",
" MyContainer('Alice', 1, 2, 3),\n",
" MyContainer('Bob', 4, 5, 6)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vk4vucGXPADj"
},
"source": [
"Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "vIr9_JOIOku7",
"outputId": "dadc9c15-4a10-4fac-e70d-f23e7085cf74",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m jax.tree_map(lambda x: x + 1, [\n\u001b[1;32m 2\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m ])\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/tree_util.py\u001b[0m in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf)\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[1;32m 185\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,\n",
"\u001b[0;32m<ipython-input-10-d6b45a2ec2b9>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m jax.tree_map(lambda x: x + 1, [\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Alice'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mMyContainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Bob'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m ])\n",
"\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for +: 'MyContainer' and 'int'"
]
}
],
"source": [
"jax.tree_map(lambda x: x + 1, [\n",
" MyContainer('Alice', 1, 2, 3),\n",
" MyContainer('Bob', 4, 5, 6)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nAZ4FR2lPN51",
"tags": [
"raises-exception"
]
},
"source": [
"To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "D_juQx-2OybX",
"outputId": "ee2cf4ad-ec21-4636-c9c5-2c64b81429bb"
},
"outputs": [
{
"data": {
"text/plain": [
"[1, 2, 3, 4, 5, 6]"
]
},
"execution_count": 11,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from typing import Tuple, Iterable\n",
"\n",
"def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:\n",
"  \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
"  flat_contents = [container.a, container.b, container.c]\n",
"\n",
"  # We don't want the name to appear as a child, so it is auxiliary data.\n",
"  # Auxiliary data is usually a description of the structure of a node,\n",
"  # e.g., the keys of a dict -- anything that isn't a node's children.\n",
"  aux_data = container.name\n",
"  return flat_contents, aux_data\n",
"\n",
"def unflatten_MyContainer(\n",
"    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:\n",
"  \"\"\"Converts aux data and the flat contents into a MyContainer.\"\"\"\n",
"  return MyContainer(aux_data, *flat_contents)\n",
"\n",
"jax.tree_util.register_pytree_node(\n",
" MyContainer, flatten_MyContainer, unflatten_MyContainer)\n",
"\n",
"jax.tree_leaves([\n",
" MyContainer('Alice', 1, 2, 3),\n",
" MyContainer('Bob', 4, 5, 6)\n",
"])"
]
},
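{
"cell_type": "markdown",
"metadata": {},
"source": [
"JAX also offers a class decorator, `jax.tree_util.register_pytree_node_class`. It expects the class to define a `tree_flatten` method and a `tree_unflatten` classmethod. Below is a sketch of the same container written that way, followed by the `tree_map` call that failed earlier. `MyDecoratedContainer` is a hypothetical name used only for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"@jax.tree_util.register_pytree_node_class\n",
"class MyDecoratedContainer:\n",
"  # Same fields as MyContainer; `name` is kept as auxiliary data.\n",
"\n",
"  def __init__(self, name, a, b, c):\n",
"    self.name = name\n",
"    self.a = a\n",
"    self.b = b\n",
"    self.c = c\n",
"\n",
"  def tree_flatten(self):\n",
"    # Return (children, aux_data), just like flatten_MyContainer above.\n",
"    return (self.a, self.b, self.c), self.name\n",
"\n",
"  @classmethod\n",
"  def tree_unflatten(cls, aux_data, children):\n",
"    return cls(aux_data, *children)\n",
"\n",
"result = jax.tree_util.tree_map(lambda x: x + 1, [\n",
"    MyDecoratedContainer('Alice', 1, 2, 3),\n",
"    MyDecoratedContainer('Bob', 4, 5, 6),\n",
"])\n",
"print([(c.name, c.a, c.b, c.c) for c in result])\n",
"# [('Alice', 2, 3, 4), ('Bob', 5, 6, 7)]"
]
},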
{
"cell_type": "markdown",
"metadata": {
"id": "JgnAp7fFShEB"
},
"source": [
"Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "8DNoLABtO0fr",
"outputId": "9a448508-43eb-4450-bfaf-eeeb59a9e349"
},
"outputs": [
{
"data": {
"text/plain": [
"['Alice', 1, 2, 3, 'Bob', 4, 5, 6]"
]
},
"execution_count": 12,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"from typing import NamedTuple, Any\n",
"\n",
"class MyOtherContainer(NamedTuple):\n",
"  name: str\n",
"  a: Any\n",
"  b: Any\n",
"  c: Any\n",
"\n",
"# JAX treats namedtuples, including `typing.NamedTuple` subclasses, as pytree\n",
"# nodes out of the box, so this works without any registration:\n",
"jax.tree_leaves([\n",
" MyOtherContainer('Alice', 1, 2, 3),\n",
" MyOtherContainer('Bob', 4, 5, 6)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TVdtzJDVTZb6"
},
"source": [
"Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way."
]
},
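{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see that price is to map over a `MyOtherContainer`: the function is applied to `name` as well. In this sketch, the class is redefined so the cell stands alone. Note that the result is still a `MyOtherContainer`, because JAX rebuilds a namedtuple using its own type:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, NamedTuple\n",
"import jax\n",
"\n",
"class MyOtherContainer(NamedTuple):\n",
"  name: str\n",
"  a: Any\n",
"  b: Any\n",
"  c: Any\n",
"\n",
"# `name` is a leaf too, so `x * 2` is applied to it; for a string that repeats it.\n",
"doubled = jax.tree_util.tree_map(lambda x: x * 2, MyOtherContainer('Alice', 1, 2, 3))\n",
"print(doubled)                 # MyOtherContainer(name='AliceAlice', a=2, b=4, c=6)\n",
"print(type(doubled).__name__)  # MyOtherContainer"
]
},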
{
"cell_type": "markdown",
"metadata": {
"id": "kNsTszcEEHD0"
},
"source": [
"## Common pytree gotchas and patterns"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ki-JDENzyL7"
},
"source": [
"### Gotchas\n",
"#### Mistaking nodes for leaves\n",
"A common problem to look out for is accidentally introducing tree nodes instead of leaves:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "N-th4jOAGJlM",
"outputId": "23eed14d-d383-4d88-d6f9-02bac06020df"
},
"outputs": [
{
"data": {
"text/plain": [
"[(DeviceArray([1., 1.], dtype=float32),\n",
" DeviceArray([1., 1., 1.], dtype=float32)),\n",
" (DeviceArray([1., 1., 1.], dtype=float32),\n",
" DeviceArray([1., 1., 1., 1.], dtype=float32))]"
]
},
"execution_count": 13,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
"\n",
"# Try to make another tree with ones instead of zeros\n",
"shapes = jax.tree_map(lambda x: x.shape, a_tree)\n",
"jax.tree_map(jnp.ones, shapes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q8d4y-hfHTWh"
},
"source": [
"What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.\n",
"\n",
"The solution will depend on the specifics, but there are two broadly applicable options:\n",
"* rewrite the code to avoid the intermediate `tree_map`.\n",
"* convert the tuple into an `np.array` or `jnp.array`, which makes the entire sequence a leaf."
]
},
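{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two possible fixes are sketched below. The first follows the first option in the list: it avoids the intermediate tree of shapes by reading each shape inside a single map. The second is an additional option not in that list. It uses the `is_leaf` argument of `tree_map` to stop the traversal at tuples, so each shape reaches `jnp.ones` whole:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
"\n",
"# Fix 1: no intermediate tree of shapes.\n",
"ones_tree = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), a_tree)\n",
"\n",
"# Fix 2: keep the tree of shapes, but treat each tuple as a leaf.\n",
"shapes = jax.tree_util.tree_map(lambda x: x.shape, a_tree)\n",
"ones_tree_2 = jax.tree_util.tree_map(\n",
"    jnp.ones, shapes, is_leaf=lambda x: isinstance(x, tuple))\n",
"\n",
"print([x.shape for x in ones_tree])    # [(2, 3), (3, 4)]\n",
"print([x.shape for x in ones_tree_2])  # [(2, 3), (3, 4)]"
]
},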
{
"cell_type": "markdown",
"metadata": {
"id": "4OKlbFlEIda-"
},
"source": [
"#### Handling of None\n",
"`jax.tree_util` treats `None` as a node without children, not as a leaf:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "gIwlwo2MJcEC",
"outputId": "1e59f323-a7b7-42be-8603-afa4693c00cc"
},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 14,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_leaves([None, None, None])"
]
},
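{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `None` has no children, `tree_map` keeps it in place instead of passing it to the function. If you want `None` to be treated as a leaf, one option is an `is_leaf` predicate. This sketch passes one to `jax.tree_util.tree_flatten`, using a made-up list:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# None is kept as structure; the function only sees 1 and 2.\n",
"mapped = jax.tree_util.tree_map(lambda x: x * 10, [1, None, 2])\n",
"print(mapped)  # [10, None, 20]\n",
"\n",
"# With is_leaf, None becomes a leaf and shows up among the flattened leaves.\n",
"none_leaves, _ = jax.tree_util.tree_flatten([1, None, 2], is_leaf=lambda x: x is None)\n",
"print(none_leaves)  # [1, None, 2]"
]
},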
{
"cell_type": "markdown",
"metadata": {
"id": "pwNz-rp1JvW4"
},
"source": [
"### Patterns\n",
"#### Transposing trees\n",
"\n",
"If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_map`:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "UExN7-G7qU-F",
"outputId": "fd049086-ef37-44db-8e2c-9f1bd9fad950"
},
"outputs": [
{
"data": {
"text/plain": [
"{'obs': [3, 4], 't': [1, 2]}"
]
},
"execution_count": 15,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def tree_transpose(list_of_trees):\n",
"  \"\"\"Convert a list of trees of identical structure into a single tree of lists.\"\"\"\n",
"  return jax.tree_map(lambda *xs: list(xs), *list_of_trees)\n",
"\n",
"\n",
"# Convert a dataset from row-major to column-major:\n",
"episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]\n",
"tree_transpose(episode_steps)"
]
},
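{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the leaves are arrays, a closely related pattern is to stack them instead of collecting them into lists. This turns a list of examples into a single batched tree. The sketch below assumes every tree in the list has the same structure and matching leaf shapes. `tree_stack` and the `steps` data are made up for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def tree_stack(list_of_trees):\n",
"  # Stack corresponding leaves along a new leading axis.\n",
"  return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *list_of_trees)\n",
"\n",
"steps = [dict(t=1, obs=jnp.array([0.0, 1.0])), dict(t=2, obs=jnp.array([2.0, 3.0]))]\n",
"batched = tree_stack(steps)\n",
"print(jax.tree_util.tree_map(lambda x: x.shape, batched))  # {'obs': (2, 2), 't': (2,)}"
]
},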
{
"cell_type": "markdown",
"metadata": {
"id": "Ao6R2ffm2CF4"
},
"source": [
"For more complicated transposes, JAX provides `jax.tree_transpose`, which is more verbose, but allows you to specify the structure of the inner and outer pytree for more flexibility:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "bZvVwxshz1D3",
"outputId": "a0314dc8-4267-41e6-a763-931d40433c26"
},
"outputs": [
{
"data": {
"text/plain": [
"{'obs': [3, 4], 't': [1, 2]}"
]
},
"execution_count": 16,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.tree_transpose(\n",
" outer_treedef = jax.tree_structure([0 for e in episode_steps]),\n",
" inner_treedef = jax.tree_structure(episode_steps[0]),\n",
" pytree_to_transpose = episode_steps\n",
")"
]
},
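{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `jax.tree_transpose` is driven entirely by the two treedefs, swapping their roles should undo the transpose. This sketch reuses the same `episode_steps` data, redefined so the cell stands alone. After the first transpose the dict is on the outside and the list on the inside, so the second call passes the dict treedef as `outer_treedef`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]\n",
"outer = jax.tree_util.tree_structure([0 for _ in episode_steps])  # list of 2 leaves\n",
"inner = jax.tree_util.tree_structure(episode_steps[0])            # dict with keys obs, t\n",
"\n",
"columns = jax.tree_util.tree_transpose(outer, inner, episode_steps)\n",
"print(columns)  # {'obs': [3, 4], 't': [1, 2]}\n",
"\n",
"# Swap the roles: the dict is now outside and the list inside.\n",
"rows = jax.tree_util.tree_transpose(inner, outer, columns)\n",
"print(rows)  # [{'obs': 3, 't': 1}, {'obs': 4, 't': 2}]"
]
},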
{
"cell_type": "markdown",
"metadata": {
"id": "KlYA2R6N2h_8"
},
"source": [
"## More Information\n",
"\n",
"For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "jax101-pytrees",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}