Minor edits

This commit is contained in:
George Necula 2019-10-30 04:52:46 +01:00
parent b12a8019c8
commit 8777864c96

View File

@ -3,7 +3,7 @@
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX pytrees",
"name": "JAX_pytrees.ipynb",
"provenance": [],
"collapsed_sections": []
},
@ -46,7 +46,7 @@
"metadata": {
"id": "X8DlAmOMmufl",
"colab_type": "code",
"outputId": "2e7821f3-7d0d-48c2-e9d3-6e19cef07255",
"outputId": "f5069593-b36e-4f2d-b8f0-7642e7034bbd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
@ -69,7 +69,7 @@
"\n",
"# Reconstruct the structured output, using the original \n",
"transformed_structured = tree_unflatten(value_tree, transformed_flat)\n",
"print(\"transformed_structured={}\", transformed_structured)"
"print(\"transformed_structured={}\".format(transformed_structured))"
],
"execution_count": 1,
"outputs": [
@ -79,7 +79,7 @@
"value_flat=[1.0, 2.0, 3.0]\n",
"value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])\n",
"transformed_flat=[2.0, 4.0, 6.0]\n",
"transformed_structured={} [2.0, (4.0, 6.0)]\n"
"transformed_structured=[2.0, (4.0, 6.0)]\n"
],
"name": "stdout"
}
@ -92,7 +92,7 @@
"colab_type": "text"
},
"source": [
"Pytrees containers can be lists, tuples, dicts, namedtuple. Numeric and ndarray values are treated as leaves:"
"Pytrees containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves:"
]
},
{
@ -100,10 +100,10 @@
"metadata": {
"id": "ViXja8YxsXZC",
"colab_type": "code",
"outputId": "75ae4b11-262f-47cc-d962-0a3334043cd0",
"outputId": "ff8120b2-f1fc-4647-9e0d-c35ee87bdd2e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 425
"height": 459
}
},
"source": [
@ -128,7 +128,7 @@
" show_example(structured)\n",
" "
],
"execution_count": 0,
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
@ -148,7 +148,13 @@
"structured=None\n",
" flat=[]\n",
" tree=PyTreeDef(None, [])\n",
" unflattened=None\n",
" unflattened=None\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"structured=[0. 0.]\n",
" flat=[_FilledConstant([0., 0.], dtype=float32)]\n",
" tree=*\n",
@ -320,4 +326,4 @@
]
}
]
}
}