rocm_jax/docs/jax-101/07-state.ipynb
2024-03-07 12:40:10 -08:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Ga0xSM8xhBIm"
},
"source": [
"# Stateful Computations in JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/07-state.ipynb)\n",
"\n",
"*Authors: Vladimir Mikulik*\n",
"\n",
"This section explores how JAX constrains the implementation of stateful programs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Avjnyrjojo8z"
},
"source": [
"## Motivation\n",
"\n",
"In machine learning, program state most often comes in the form of:\n",
"* model parameters,\n",
"* optimizer state, and\n",
"* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).\n",
"\n",
"Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, a function transformed by `jax.jit` must have no side effects. This is because any side effects run only when the Python function is traced for compilation, which typically happens once per combination of input shapes and dtypes. The compiled function does not repeat them on subsequent calls.\n",
"\n",
"Changing program state is one kind of side effect. So, if we can't have side effects, how do we update model parameters and optimizer state, and how do we use stateful layers in our models? This colab explains this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming)."
]
},
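{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the constraint concrete, here is a minimal sketch (the names `trace_log` and `add_one` are hypothetical and not part of JAX). It assumes that both calls pass a Python float, so that the second call can reuse the compiled function. Under that assumption, the list append in the function body should run only while the function is traced:\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"trace_log = []  # hypothetical global, used only to observe the side effect\n",
"\n",
"@jax.jit\n",
"def add_one(x):\n",
"  # Python-level side effect: executed while JAX traces the function,\n",
"  # not when the compiled computation runs.\n",
"  trace_log.append('traced')\n",
"  return x + 1\n",
"\n",
"add_one(1.0)\n",
"add_one(2.0)  # same input type, so the cached compilation is reused\n",
"print(len(trace_log))\n",
"```\n",
"\n",
"If the cache is hit as assumed, the printed length is 1 rather than 2, even though `add_one` was called twice."
]
},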
{
"cell_type": "markdown",
"metadata": {
"id": "s_-6semKkSzp"
},
"source": [
"## A simple example: Counter\n",
"\n",
"Let's start by looking at a simple stateful program: a counter."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "B3aoCHpjg8gm",
"outputId": "5cbcfbf5-5c42-498f-a175-050438518337"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"3\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"class Counter:\n",
"  \"\"\"A simple counter.\"\"\"\n",
"\n",
"  def __init__(self):\n",
"    self.n = 0\n",
"\n",
"  def count(self) -> int:\n",
"    \"\"\"Increments the counter and returns the new value.\"\"\"\n",
"    self.n += 1\n",
"    return self.n\n",
"\n",
"  def reset(self):\n",
"    \"\"\"Resets the counter to zero.\"\"\"\n",
"    self.n = 0\n",
"\n",
"\n",
"counter = Counter()\n",
"\n",
"for _ in range(3):\n",
"  print(counter.count())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SQ-RNLfdiw04"
},
"source": [
"The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.\n",
"\n",
"Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, `jax.jit` wouldn't actually speed anything up, for many reasons, but treat it as a toy model of JIT-compiling the update of model parameters, where `jax.jit` makes an enormous difference.)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "5jSjmJMon03W",
"outputId": "d952f16b-9b30-4753-ed94-cc914a929a36"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"1\n",
"1\n"
]
}
],
"source": [
"counter.reset()\n",
"fast_count = jax.jit(counter.count)\n",
"\n",
"for _ in range(3):\n",
"  print(fast_count())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "weiI0V7_pKGv"
},
"source": [
"Oh no! Our counter isn't working. This is because the line\n",
"```\n",
"self.n += 1\n",
"```\n",
"in `count` is executed only once, when JAX traces the method in order to compile it. Moreover, since the return value doesn't depend on the arguments to `count`, once it returns the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?\n",
"\n",
"## The solution: explicit state\n",
"\n",
"Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was \"baked into\" the compiled output. But it shouldn't be a constant -- it should depend on the state. Well, then why don't we make the state into an argument?"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "53pSdK4KoOEZ",
"outputId": "5ac72b9c-7029-4bf2-de8d-1d412bd74c79"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"3\n"
]
}
],
"source": [
"CounterState = int\n",
"\n",
"class CounterV2:\n",
"\n",
"  def count(self, n: CounterState) -> tuple[int, CounterState]:\n",
"    # You could just return n+1, but here we separate its role as\n",
"    # the output and as the counter state for didactic purposes.\n",
"    return n+1, n+1\n",
"\n",
"  def reset(self) -> CounterState:\n",
"    return 0\n",
"\n",
"counter = CounterV2()\n",
"state = counter.reset()\n",
"\n",
"for _ in range(3):\n",
"  value, state = counter.count(state)\n",
"  print(value)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PrBjmgZtq89b"
},
"source": [
"In this new version of `Counter`, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "LO4Xzcq_q8PH",
"outputId": "25c06a56-f2bf-4c54-a3c3-6e093d484362"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"3\n"
]
}
],
"source": [
"state = counter.reset()\n",
"fast_count = jax.jit(counter.count)\n",
"\n",
"for _ in range(3):\n",
"  value, state = fast_count(state)\n",
"  print(value)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MzMSWD2_sgnh"
},
"source": [
"## A general strategy\n",
"\n",
"We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form\n",
"\n",
"```python\n",
"class StatefulClass:\n",
"\n",
"  state: State\n",
"\n",
"  def stateful_method(self, *args, **kwargs) -> Output:\n",
"    ...\n",
"```\n",
"\n",
"and turned it into a class of the form\n",
"\n",
"```python\n",
"class StatelessClass:\n",
"\n",
"  def stateless_method(self, state: State, *args, **kwargs) -> tuple[Output, State]:\n",
"    ...\n",
"```\n",
"\n",
"This is a common [functional programming](https://en.wikipedia.org/wiki/Functional_programming) pattern, and, essentially, is the way that state is handled in all JAX programs.\n",
"\n",
"Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state, and passing the state explicitly already does that job.\n",
"\n",
"In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?\n",
"\n",
"Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike NumPy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key."
]
},
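{
"cell_type": "markdown",
"metadata": {},
"source": [
"The PRNG key can be read through the same lens. As a rough sketch (the helper `draw` is hypothetical, not a JAX function), the key plays the role of `State`: each call takes the current key and returns the output together with the key to use next.\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"def draw(key):\n",
"  \"\"\"Returns a standard normal sample and the key for the next call.\"\"\"\n",
"  key, subkey = jax.random.split(key)\n",
"  return jax.random.normal(subkey, ()), key\n",
"\n",
"key = jax.random.key(0)\n",
"for _ in range(3):\n",
"  sample, key = draw(key)\n",
"  print(sample)\n",
"```\n",
"\n",
"Note that `draw` is a plain function rather than a method, which is one possible answer to the exercise above."
]
},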
{
"cell_type": "markdown",
"metadata": {
"id": "I2SqRx14_z98"
},
"source": [
"## Simple worked example: Linear Regression\n",
"\n",
"Let's apply this strategy to a simple machine learning model: linear regression via gradient descent.\n",
"\n",
"Here, we only deal with one kind of state: the model parameters. But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for BatchNorm, and others.\n",
"\n",
"The function to look at carefully is `update`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "wQdU7DoAseW6"
},
"outputs": [],
"source": [
"from typing import NamedTuple\n",
"\n",
"class Params(NamedTuple):\n",
"  weight: jnp.ndarray\n",
"  bias: jnp.ndarray\n",
"\n",
"\n",
"def init(rng) -> Params:\n",
"  \"\"\"Returns the initial model params.\"\"\"\n",
"  weights_key, bias_key = jax.random.split(rng)\n",
"  weight = jax.random.normal(weights_key, ())\n",
"  bias = jax.random.normal(bias_key, ())\n",
"  return Params(weight, bias)\n",
"\n",
"\n",
"def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:\n",
"  \"\"\"Computes the least squares error of the model's predictions on x against y.\"\"\"\n",
"  pred = params.weight * x + params.bias\n",
"  return jnp.mean((pred - y) ** 2)\n",
"\n",
"\n",
"LEARNING_RATE = 0.005\n",
"\n",
"@jax.jit\n",
"def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:\n",
"  \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n",
"  grad = jax.grad(loss)(params, x, y)\n",
"\n",
"  # If we were using Adam or another stateful optimizer,\n",
"  # we would also do something like\n",
"  # ```\n",
"  # updates, new_optimizer_state = optimizer(grad, optimizer_state)\n",
"  # ```\n",
"  # and then use `updates` instead of `grad` to actually update the params.\n",
"  # (And we'd include `new_optimizer_state` in the output, naturally.)\n",
"\n",
"  # `jax.tree_util.tree_map` applies the update to every leaf of `params`.\n",
"  new_params = jax.tree_util.tree_map(\n",
"      lambda param, g: param - g * LEARNING_RATE, params, grad)\n",
"\n",
"  return new_params"
]
},
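{
"cell_type": "markdown",
"metadata": {},
"source": [
"The comment inside `update` mentions stateful optimizers. As a hedged illustration of what threading optimizer state could look like, here is a self-contained sketch using plain momentum instead of Adam. The names `MOMENTUM`, `mse_loss`, `momentum_update` and `velocity`, the dict-based params, and the toy data are all assumptions made for this sketch and are not part of the example above.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"LEARNING_RATE = 0.005\n",
"MOMENTUM = 0.9  # hypothetical value, chosen only for illustration\n",
"\n",
"def mse_loss(params, x, y):\n",
"  pred = params['weight'] * x + params['bias']\n",
"  return jnp.mean((pred - y) ** 2)\n",
"\n",
"@jax.jit\n",
"def momentum_update(params, velocity, x, y):\n",
"  \"\"\"One gradient step with momentum; returns new params and new velocity.\"\"\"\n",
"  grad = jax.grad(mse_loss)(params, x, y)\n",
"  new_velocity = jax.tree_util.tree_map(\n",
"      lambda v, g: MOMENTUM * v + g, velocity, grad)\n",
"  new_params = jax.tree_util.tree_map(\n",
"      lambda p, v: p - LEARNING_RATE * v, params, new_velocity)\n",
"  return new_params, new_velocity\n",
"\n",
"params = {'weight': jnp.array(0.0), 'bias': jnp.array(0.0)}\n",
"velocity = jax.tree_util.tree_map(jnp.zeros_like, params)  # optimizer state\n",
"xs = jnp.array([0.0, 1.0, 2.0])\n",
"ys = 2 * xs - 1\n",
"for _ in range(100):\n",
"  params, velocity = momentum_update(params, velocity, xs, ys)\n",
"```\n",
"\n",
"Both pieces of state, the params and the velocity, go in as arguments and come back out as return values, so the function stays free of side effects and is safe to `jax.jit`."
]
},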
{
"cell_type": "markdown",
"metadata": {
"id": "dKySWouu2-Hu"
},
"source": [
"Notice that we manually pipe the params in and out of the update function."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "jQCYYy0yxO6K",
"outputId": "1f3b69d2-e90b-4065-cbcc-6422978d25c2"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"rng = jax.random.key(42)\n",
"\n",
"# Generate true data from y = w*x + b + noise\n",
"true_w, true_b = 2, -1\n",
"# Use a separate key for each random draw so that no key is reused.\n",
"x_rng, noise_rng, init_rng = jax.random.split(rng, 3)\n",
"xs = jax.random.normal(x_rng, (128, 1))\n",
"noise = jax.random.normal(noise_rng, (128, 1)) * 0.5\n",
"ys = xs * true_w + true_b + noise\n",
"\n",
"# Fit regression\n",
"params = init(init_rng)\n",
"for _ in range(1000):\n",
"  params = update(params, xs, ys)\n",
"\n",
"plt.scatter(xs, ys)\n",
"plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1wq3L6Xg1UHP"
},
"source": [
"## Taking it further\n",
"\n",
"The strategy described above is how any (jitted) JAX program must handle state.\n",
"\n",
"Handling parameters manually seems fine if you're dealing with two parameters, but what if it's a neural net with dozens of layers? You might already be getting worried about two things:\n",
"\n",
"1) Are we supposed to initialize them all manually, essentially repeating what we have already written in the forward pass definition?\n",
"\n",
"2) Are we supposed to pipe all these things around manually?\n",
"\n",
"The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples."
]
}
],
"metadata": {
"colab": {
"name": "The Problem of State",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}