mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Minor edits
This commit is contained in:
parent
b12a8019c8
commit
8777864c96
@ -3,7 +3,7 @@
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "JAX pytrees",
|
||||
"name": "JAX_pytrees.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": []
|
||||
},
|
||||
@ -46,7 +46,7 @@
|
||||
"metadata": {
|
||||
"id": "X8DlAmOMmufl",
|
||||
"colab_type": "code",
|
||||
"outputId": "2e7821f3-7d0d-48c2-e9d3-6e19cef07255",
|
||||
"outputId": "f5069593-b36e-4f2d-b8f0-7642e7034bbd",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 85
|
||||
@ -69,7 +69,7 @@
|
||||
"\n",
|
||||
"# Reconstruct the structured output, using the original \n",
|
||||
"transformed_structured = tree_unflatten(value_tree, transformed_flat)\n",
|
||||
"print(\"transformed_structured={}\", transformed_structured)"
|
||||
"print(\"transformed_structured={}\".format(transformed_structured))"
|
||||
],
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
@ -79,7 +79,7 @@
|
||||
"value_flat=[1.0, 2.0, 3.0]\n",
|
||||
"value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])\n",
|
||||
"transformed_flat=[2.0, 4.0, 6.0]\n",
|
||||
"transformed_structured={} [2.0, (4.0, 6.0)]\n"
|
||||
"transformed_structured=[2.0, (4.0, 6.0)]\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
@ -92,7 +92,7 @@
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"Pytrees containers can be lists, tuples, dicts, namedtuple. Numeric and ndarray values are treated as leaves:"
|
||||
"Pytrees containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -100,10 +100,10 @@
|
||||
"metadata": {
|
||||
"id": "ViXja8YxsXZC",
|
||||
"colab_type": "code",
|
||||
"outputId": "75ae4b11-262f-47cc-d962-0a3334043cd0",
|
||||
"outputId": "ff8120b2-f1fc-4647-9e0d-c35ee87bdd2e",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 425
|
||||
"height": 459
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
@ -128,7 +128,7 @@
|
||||
" show_example(structured)\n",
|
||||
" "
|
||||
],
|
||||
"execution_count": 0,
|
||||
"execution_count": 2,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -148,7 +148,13 @@
|
||||
"structured=None\n",
|
||||
" flat=[]\n",
|
||||
" tree=PyTreeDef(None, [])\n",
|
||||
" unflattened=None\n",
|
||||
" unflattened=None\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"structured=[0. 0.]\n",
|
||||
" flat=[_FilledConstant([0., 0.], dtype=float32)]\n",
|
||||
" tree=*\n",
|
||||
@ -320,4 +326,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user