rocm_jax/docs/notebooks/JAX_pytrees.ipynb
2019-12-30 11:27:12 -08:00

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX_pytrees.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "g_vouXWulcNh",
"colab_type": "text"
},
"source": [
"# JAX pytrees\n",
"\n",
"Date: October 2019"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lu00ShwgaPEW",
"colab_type": "text"
},
"source": [
"**This is primarily JAX-internal documentation. End users should not need to understand it to use JAX, except when registering new user-defined container types with JAX. Some of these details may change.**\n",
"\n",
"Python has many container data types (list, tuple, dict, namedtuple, etc.), and users sometimes define their own. To keep the JAX internals simple while supporting many container types, we canonicalize nested containers into flat lists of numeric or array types at the `api.py` boundary (and also in control flow primitives). That way `grad`, `jit`, `vmap`, etc. can handle user functions that accept and return these containers, while all other parts of the system operate on functions that take only (multiple) array arguments and always return a flat list of arrays.\n",
"\n",
"We refer to a recursively structured value whose leaves are basic types as a `pytree`. When JAX flattens a pytree, it produces a list of leaves and a `treedef` object that encodes the structure of the original value. The `treedef` can then be used to construct a matching structured value after the leaves have been transformed. Pytrees are tree-like, rather than DAG-like or graph-like, in that we handle them assuming referential transparency and assuming that they cannot contain reference cycles.\n",
"\n",
"Here is a simple example:\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "X8DlAmOMmufl",
"colab_type": "code",
"outputId": "f5069593-b36e-4f2d-b8f0-7642e7034bbd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node\n",
"from jax import numpy as np\n",
"\n",
"# The structured value to be transformed\n",
"value_structured = [1., (2., 3.)]\n",
"\n",
"# The leaves in value_flat correspond to the `*` markers in value_tree\n",
"value_flat, value_tree = tree_flatten(value_structured)\n",
"print(\"value_flat={}\\nvalue_tree={}\".format(value_flat, value_tree))\n",
"\n",
"# Transform the flat value list using an element-wise numeric transformer\n",
"transformed_flat = list(map(lambda v: v * 2., value_flat))\n",
"print(\"transformed_flat={}\".format(transformed_flat))\n",
"\n",
"# Reconstruct the structured output, using the original treedef\n",
"transformed_structured = tree_unflatten(value_tree, transformed_flat)\n",
"print(\"transformed_structured={}\".format(transformed_structured))"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"value_flat=[1.0, 2.0, 3.0]\n",
"value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])\n",
"transformed_flat=[2.0, 4.0, 6.0]\n",
"transformed_structured=[2.0, (4.0, 6.0)]\n"
],
"name": "stdout"
}
]
},
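{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The flatten, transform, and unflatten steps above can also be written as a single call. The following is a minimal sketch, assuming `tree_map` from `jax.tree_util` (a function not otherwise used in this notebook), which applies a function to every leaf and rebuilds the same structure:"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"from jax.tree_util import tree_map\n",
"\n",
"# Same structured value as in the example above\n",
"value_structured = [1., (2., 3.)]\n",
"\n",
"# tree_map flattens, applies the function to each leaf, and unflattens\n",
"doubled = tree_map(lambda v: v * 2., value_structured)\n",
"print(\"doubled={}\".format(doubled))"
],
"execution_count": null,
"outputs": []
},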
{
"cell_type": "markdown",
"metadata": {
"id": "sgUJpiXSsRSi",
"colab_type": "text"
},
"source": [
"Pytree containers can be lists, tuples, dicts, namedtuples, `None`, or `OrderedDict`s. Values of other types, including numeric and ndarray values, are treated as leaves:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ViXja8YxsXZC",
"colab_type": "code",
"outputId": "ff8120b2-f1fc-4647-9e0d-c35ee87bdd2e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 459
}
},
"source": [
"from collections import namedtuple\n",
"Point = namedtuple('Point', ['x', 'y'])\n",
"\n",
"example_containers = [\n",
"    (1., [2., 3.]),\n",
"    (1., {'b': 2., 'a': 3.}),\n",
"    1.,\n",
"    None,\n",
"    np.zeros(2),\n",
"    Point(1., 2.)\n",
"]\n",
"def show_example(structured):\n",
"  flat, tree = tree_flatten(structured)\n",
"  unflattened = tree_unflatten(tree, flat)\n",
"  print(\"structured={}\\n flat={}\\n tree={}\\n unflattened={}\".format(\n",
"      structured, flat, tree, unflattened))\n",
"\n",
"for structured in example_containers:\n",
"  show_example(structured)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"structured=(1.0, [2.0, 3.0])\n",
" flat=[1.0, 2.0, 3.0]\n",
" tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])])\n",
" unflattened=(1.0, [2.0, 3.0])\n",
"structured=(1.0, {'b': 2.0, 'a': 3.0})\n",
" flat=[1.0, 3.0, 2.0]\n",
" tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])])\n",
" unflattened=(1.0, {'a': 3.0, 'b': 2.0})\n",
"structured=1.0\n",
" flat=[1.0]\n",
" tree=*\n",
" unflattened=1.0\n",
"structured=None\n",
" flat=[]\n",
" tree=PyTreeDef(None, [])\n",
" unflattened=None\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"structured=[0. 0.]\n",
" flat=[_FilledConstant([0., 0.], dtype=float32)]\n",
" tree=*\n",
" unflattened=[0. 0.]\n",
"structured=Point(x=1.0, y=2.0)\n",
" flat=[1.0, 2.0]\n",
" tree=PyTreeDef(namedtuple[<class '__main__.Point'>], [*,*])\n",
" unflattened=Point(x=1.0, y=2.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f5iYkKRx2ILR",
"colab_type": "text"
},
"source": [
"## Pytrees are extensible"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Cb35Y8vBtVKp",
"colab_type": "text"
},
"source": [
"By default, any part of a structured value that is not recognized as an internal pytree node is treated as a leaf (so such containers cannot be passed to JAX-traceable functions):\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "17JyT_arth7P",
"colab_type": "code",
"outputId": "142485b6-c7fb-4bfd-b9a3-c01e8be7e9e4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"class Special(object):\n",
"  def __init__(self, x, y):\n",
"    self.x = x\n",
"    self.y = y\n",
"\n",
"  def __repr__(self):\n",
"    return \"Special(x={}, y={})\".format(self.x, self.y)\n",
"\n",
"\n",
"show_example(Special(1., 2.))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"structured=Special(x=1.0, y=2.0)\n",
" flat=[Special(x=1.0, y=2.0)]\n",
" tree=*\n",
" unflattened=Special(x=1.0, y=2.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3y9NECzRthKi",
"colab_type": "text"
},
"source": [
"The set of Python types that are considered internal pytree nodes is extensible, through a global registry of types. Values of registered types\n",
"are traversed recursively:\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3Emk3EN5uPMr",
"colab_type": "code",
"outputId": "4b5b3ff6-6b80-424c-a6c0-8da97b943a7b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"class RegisteredSpecial(Special):\n",
"  def __repr__(self):\n",
"    return \"RegisteredSpecial(x={}, y={})\".format(self.x, self.y)\n",
"\n",
"def special_flatten(v):\n",
"  \"\"\"Specifies a flattening recipe.\n",
"\n",
"  Params:\n",
"    v: the value of registered type to flatten.\n",
"  Returns:\n",
"    a pair of an iterable with the children to be flattened recursively,\n",
"    and some opaque auxiliary data to pass back to the unflattening recipe.\n",
"    The auxiliary data is stored in the treedef for use during unflattening.\n",
"    The auxiliary data could be used, e.g., for dictionary keys.\n",
"  \"\"\"\n",
"  children = (v.x, v.y)\n",
"  aux_data = None\n",
"  return (children, aux_data)\n",
"\n",
"def special_unflatten(aux_data, children):\n",
"  \"\"\"Specifies an unflattening recipe.\n",
"\n",
"  Params:\n",
"    aux_data: the opaque data that was specified during flattening of the\n",
"      current treedef.\n",
"    children: the unflattened children\n",
"\n",
"  Returns:\n",
"    a re-constructed object of the registered type, using the specified\n",
"    children and auxiliary data.\n",
"  \"\"\"\n",
"  return RegisteredSpecial(*children)\n",
"\n",
"# Global registration\n",
"register_pytree_node(\n",
"    RegisteredSpecial,\n",
"    special_flatten,    # tell JAX what the child nodes are\n",
"    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial\n",
")\n",
"\n",
"show_example(RegisteredSpecial(1., 2.))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"structured=RegisteredSpecial(x=1.0, y=2.0)\n",
" flat=[1.0, 2.0]\n",
" tree=PyTreeDef(<class '__main__.RegisteredSpecial'>[None], [*,*])\n",
" unflattened=RegisteredSpecial(x=1.0, y=2.0)\n"
],
"name": "stdout"
}
]
},
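{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Once a type is registered, its instances can be passed to and returned from JAX-transformed functions. The following is a minimal sketch under stated assumptions: `Pair` and `scale` are hypothetical names introduced only for illustration, and a fresh class is used so that no type is registered twice:"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"import jax\n",
"from jax.tree_util import register_pytree_node\n",
"\n",
"class Pair(object):  # hypothetical container, for illustration only\n",
"  def __init__(self, a, b):\n",
"    self.a = a\n",
"    self.b = b\n",
"\n",
"register_pytree_node(\n",
"    Pair,\n",
"    lambda p: ((p.a, p.b), None),                # (children, aux_data)\n",
"    lambda aux_data, children: Pair(*children))  # rebuild from children\n",
"\n",
"@jax.jit\n",
"def scale(pair, k):\n",
"  # Inside jit, pair is rebuilt from traced leaves by the unflatten recipe\n",
"  return Pair(pair.a * k, pair.b * k)\n",
"\n",
"out = scale(Pair(1., 2.), 3.)\n",
"print(out.a, out.b)"
],
"execution_count": null,
"outputs": []
},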
{
"cell_type": "markdown",
"metadata": {
"id": "TIBrH5KaxImR",
"colab_type": "text"
},
"source": [
"JAX sometimes needs to compare treedefs for equality. Care must therefore be taken to ensure that the auxiliary data specified in the flattening recipe supports a meaningful equality comparison."
]
},
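{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The equality requirement on auxiliary data can be illustrated with a small sketch. `Tagged` is a hypothetical class introduced only for this example; its string tag is stored as auxiliary data, and because strings compare by value, two treedefs are equal exactly when their tags are equal. `tree_structure` is assumed to be available from `jax.tree_util`:"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code"
},
"source": [
"from jax.tree_util import register_pytree_node, tree_structure\n",
"\n",
"class Tagged(object):  # hypothetical container, for illustration only\n",
"  def __init__(self, value, tag):\n",
"    self.value = value\n",
"    self.tag = tag\n",
"\n",
"register_pytree_node(\n",
"    Tagged,\n",
"    lambda t: ((t.value,), t.tag),                   # tag becomes aux_data\n",
"    lambda tag, children: Tagged(children[0], tag))  # rebuild with same tag\n",
"\n",
"# Same aux_data, different leaf values: the treedefs compare equal\n",
"same = tree_structure(Tagged(1., \"a\")) == tree_structure(Tagged(2., \"a\"))\n",
"# Different aux_data: the treedefs compare unequal\n",
"different = tree_structure(Tagged(1., \"a\")) == tree_structure(Tagged(1., \"b\"))\n",
"print(\"same={}\\ndifferent={}\".format(same, different))"
],
"execution_count": null,
"outputs": []
},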
{
"cell_type": "markdown",
"metadata": {
"id": "Qoi69-I64_qe",
"colab_type": "text"
},
"source": [
"The full set of functions for operating on pytrees is in the [`tree_util` module](https://jax.readthedocs.io/en/latest/jax.tree_util.html).\n"
]
}
]
}