{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "6_117sy0CGEU"
},
"source": [
"# JAX as accelerated NumPy\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/01-jax-basics.ipynb)\n",
"\n",
"*Authors: Rosalia Schneider & Vladimir Mikulik*\n",
"\n",
"In this first section you will learn the very fundamentals of JAX."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CXjHL4L6ku3-"
},
"source": [
"## Getting started with JAX numpy\n",
"\n",
"Fundamentally, JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API. \n",
"\n",
"Over the course of this series of guides, we will unpack exactly what that means. For now, you can think of JAX as *differentiable NumPy that runs on accelerators*.\n",
"\n",
"The code below shows how to import JAX and create a vector."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "ZqUzvqF1B1TO"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 1 2 3 4 5 6 7 8 9]\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"x = jnp.arange(10)\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rPBmlAxXlBAy"
},
"source": [
"So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you replaced `np` with `jnp`. However, there are some important differences, which we touch on at the end of this section.\n",
"\n",
"You can notice the first difference if you check the type of `x`: it is a `DeviceArray`, which is how JAX represents arrays."
]
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "3fLtgPUAn7mi"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"x"
]
},
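{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of that correspondence (a minimal sketch comparing plain NumPy with `jnp`), the same expression gives the same result in both libraries:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# The same computation in NumPy and in JAX's NumPy API.\n",
"print(np.sum(np.arange(10) ** 2))   # 285, as a NumPy scalar\n",
"print(jnp.sum(jnp.arange(10) ** 2)) # 285, as a DeviceArray"
]
},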
{
"cell_type": "markdown",
"metadata": {
"id": "Yx8VofzzoHFH"
},
"source": [
"One useful feature of JAX is that the same code can be run on different backends -- CPU, GPU and TPU.\n",
"\n",
"We will now perform a dot product to demonstrate that it can be run on different devices without changing the code. We use `%timeit` to check the performance.\n",
"\n",
"(Technical detail: when a JAX function is called, the corresponding operation is dispatched to an accelerator to be computed asynchronously when possible. The returned array is therefore not necessarily 'filled in' as soon as the function returns. Thus, if we don't require the result immediately, the computation won't block Python execution, and unless we call `block_until_ready` we will only time the dispatch, not the actual computation. See [Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch) in the JAX docs.)"
]
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "mRvjVxoqo-Bi"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The slowest run took 7.39 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"100 loops, best of 5: 7.85 ms per loop\n"
]
}
],
"source": [
"long_vector = jnp.arange(int(1e7))\n",
"\n",
"%timeit jnp.dot(long_vector, long_vector).block_until_ready()"
]
},
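{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see asynchronous dispatch at work, try timing the same dot product without `block_until_ready` (a minimal sketch; on CPU the gap may be small, but on an accelerator it is usually dramatic):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%timeit jnp.dot(long_vector, long_vector)  # times only the dispatch, not the computation"
]
},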
{
"cell_type": "markdown",
"metadata": {
"id": "DKBB0zs-p-RC"
},
"source": [
"**Tip**: Try running the code above twice, once without an accelerator, and once with a GPU runtime (while in Colab, click *Runtime* → *Change Runtime Type* and choose `GPU`). Notice how much faster it runs on a GPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkCpI-v0uQQO"
},
"source": [
"## JAX first transformation: `grad`\n",
"\n",
"A fundamental feature of JAX is that it allows you to transform functions.\n",
"\n",
"One of the most commonly used transformations is `jax.grad`, which takes a numerical function written in Python and returns a new Python function that computes the gradient of the original function.\n",
"\n",
"To use it, let's first define a function that takes an array and returns the sum of squares."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "LuaGUVRUvbzQ"
},
"outputs": [],
"source": [
"def sum_of_squares(x):\n",
" return jnp.sum(x**2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QAqloI1Wvtp2"
},
"source": [
"Applying `jax.grad` to `sum_of_squares` will return a different function, namely the gradient of `sum_of_squares` with respect to its first parameter `x`.\n",
"\n",
"Then, you can apply that function to an array to compute the derivative with respect to each element of the array."
]
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "dKeorwJfvpeI"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30.0\n",
"[2. 4. 6. 8.]\n"
]
}
],
"source": [
"sum_of_squares_dx = jax.grad(sum_of_squares)\n",
"\n",
"x = jnp.asarray([1.0, 2.0, 3.0, 4.0])\n",
"\n",
"print(sum_of_squares(x))\n",
"\n",
"print(sum_of_squares_dx(x))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VfBt5CYbyKUX"
},
"source": [
"You can think of `jax.grad` by analogy to the $\\nabla$ operator from vector calculus. Given a function $f(x)$, $\\nabla f$ represents the function that computes $f$'s gradient, i.e.\n",
"\n",
"$$\n",
"(\\nabla f)(x)_i = \\frac{\\partial f}{\\partial x_i}(x).\n",
"$$\n",
"\n",
"Analogously, `jax.grad(f)` is the function that computes the gradient, so `jax.grad(f)(x)` is the gradient of `f` at `x`.\n",
"\n",
"(Like $\\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)\n",
"\n",
"This makes the JAX API quite different from other autodiff libraries like TensorFlow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n",
"\n",
"This way of doing things also makes it straightforward to control which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "f3NfaVu4yrQE"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.20000005 -0.19999981 -0.19999981 -0.19999981]\n"
]
}
],
"source": [
"def sum_squared_error(x, y):\n",
" return jnp.sum((x-y)**2)\n",
"\n",
"sum_squared_error_dx = jax.grad(sum_squared_error)\n",
"\n",
"y = jnp.asarray([1.1, 2.1, 3.1, 4.1])\n",
"\n",
"print(sum_squared_error_dx(x, y))"
]
},
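{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, we can compare the first entry of the gradient against a numerical estimate (a minimal sketch; `eps` and the basis vector `e0` are just illustrative helpers):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eps = 1e-3\n",
"e0 = jnp.array([1., 0., 0., 0.])  # perturb only the first coordinate\n",
"\n",
"# Central difference: (f(x + eps*e0) - f(x - eps*e0)) / (2*eps) ~ df/dx_0.\n",
"numerical = (sum_squared_error(x + eps * e0, y)\n",
"             - sum_squared_error(x - eps * e0, y)) / (2 * eps)\n",
"\n",
"print(numerical)                      # ~ -0.2, up to float32 noise\n",
"print(sum_squared_error_dx(x, y)[0])  # the autodiff answer, for comparison"
]
},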
{
"cell_type": "markdown",
"metadata": {
"id": "1tOztA5zpLWN"
},
"source": [
"To find the gradient with respect to a different argument (or several), you can set `argnums`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "FQSczVQkqIPY"
},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))"
]
},
"execution_count": 7,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.grad(sum_squared_error, argnums=(0, 1))(x, y) # Find gradient wrt both x & y"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yQAMTnZSqo-t"
},
"source": [
"Does this mean that when doing machine learning, we need to write functions with gigantic argument lists, with an argument for each model parameter array? No. JAX comes equipped with machinery for bundling arrays together in data structures called 'pytrees', which we cover in more detail in a [later guide](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/05.1-pytrees.ipynb). So, most often, use of `jax.grad` looks like this:\n",
"\n",
"```\n",
"def loss_fn(params, data):\n",
" ...\n",
"\n",
"grads = jax.grad(loss_fn)(params, data_batch)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oBowiovisT97"
},
"source": [
"where `params` is, for example, a nested dict of arrays, and the returned `grads` is another nested dict of arrays with the same structure."
]
},
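{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, here is a toy sketch (the function `toy_loss` is made up for illustration): differentiating with respect to a dict of arrays yields a dict of arrays with the same keys and shapes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params = {\"w\": jnp.ones(3), \"b\": jnp.zeros(())}\n",
"\n",
"def toy_loss(params, x):\n",
"  return jnp.sum(params[\"w\"] * x) + params[\"b\"]\n",
"\n",
"# grads mirrors the structure of params: {'b': ..., 'w': ...}.\n",
"grads = jax.grad(toy_loss)(params, jnp.arange(3.))\n",
"print(grads)"
]
},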
{
"cell_type": "markdown",
"metadata": {
"id": "LNjf9jUEsZZ8"
},
"source": [
"## Value and Grad\n",
"\n",
"Often, you need to find both the value and the gradient of a function, e.g. if you want to log the training loss. JAX has a handy sister transformation for efficiently doing that:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "dWg4_-h3sYwl"
},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray(0.03999995, dtype=float32),\n",
" DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))"
]
},
"execution_count": 8,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.value_and_grad(sum_squared_error)(x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QVT2EWHJsvvv"
},
"source": [
"which returns a tuple of, you guessed it, (value, grad). To be precise, for any `f`,\n",
"\n",
"```\n",
"jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs)) \n",
"```"
]
},
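{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can verify this identity directly on our example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"value, grad = jax.value_and_grad(sum_squared_error)(x, y)\n",
"\n",
"print(value == sum_squared_error(x, y))                       # True\n",
"print(jnp.allclose(grad, jax.grad(sum_squared_error)(x, y)))  # True"
]
},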
{
"cell_type": "markdown",
"metadata": {
"id": "QmHTVpAks3OX"
},
"source": [
"## Auxiliary data\n",
"\n",
"In addition to wanting to log the value, we often want to report some intermediate results obtained in computing the loss function. But if we try doing that with regular `jax.grad`, we run into trouble:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "ffGCEzT4st41",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-7433a86e7375>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames.\nThe following is the original exception that occurred, unmodified.\n\n--------------------",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 825\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 826\u001b[0;31m \u001b[0maval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 827\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mget_aval\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 913\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 914\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mconcrete_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 915\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mconcrete_aval\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 906\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandler\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 907\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{type(x)} is not a valid JAX type\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 908\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: <class 'tuple'> is not a valid JAX type",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-7433a86e7375>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msum_squared_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 743\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mapi_boundary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 744\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 745\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 746\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 747\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreraise_with_filtered_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_under_reraiser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 811\u001b[0;31m \u001b[0m_check_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 812\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 813\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_output_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 826\u001b[0m \u001b[0maval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 827\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 828\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"was {x}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 829\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 830\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mShapedArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))."
]
}
],
"source": [
"def squared_error_with_aux(x, y):\n",
" return sum_squared_error(x, y), x-y\n",
"\n",
"jax.grad(squared_error_with_aux)(x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IUubno3nth4i"
},
"source": [
"This is because `jax.grad` is only defined for scalar-output functions, and our new function returns a tuple. But we need that tuple to get our intermediate results out! This is where `has_aux` comes in:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "uzUFihyatgiF"
},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))"
]
},
"execution_count": 10,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jax.grad(squared_error_with_aux, has_aux=True)(x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "g5s3UiFauwDk"
},
"source": [
"`has_aux` signifies that the function returns a pair, `(out, aux)`. It makes `jax.grad` ignore `aux`, passing it through to the user, while differentiating the function as if only `out` were returned."
]
},
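{
"cell_type": "markdown",
"metadata": {},
"source": [
"`has_aux` also composes with `jax.value_and_grad`; in that case, the value slot becomes the `(out, aux)` pair:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(value, aux), grad = jax.value_and_grad(squared_error_with_aux, has_aux=True)(x, y)\n",
"\n",
"print(value)  # the scalar loss\n",
"print(aux)    # the auxiliary x - y\n",
"print(grad)   # the gradient with respect to x"
]
},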
{
"cell_type": "markdown",
"metadata": {
"id": "fk4FUXe7vsW4"
},
"source": [
"## Differences from NumPy\n",
"\n",
"The `jax.numpy` API closely follows that of NumPy. However, there are some important differences. We cover many of these in future guides, but it's worth pointing some out now.\n",
"\n",
"The most important difference, and in some sense the root of all the rest, is that JAX is designed to be _functional_, as in _functional programming_. The reason behind this is that the kinds of program transformations that JAX enables are much more feasible in functional-style programs.\n",
"\n",
"An introduction to functional programming (FP) is out of the scope of this guide. If you are already familiar with FP, you will find your FP intuition helpful while learning JAX. If not, don't worry! The important feature of functional programming to grok when working with JAX is very simple: don't write code with side-effects.\n",
"\n",
"A side-effect is any effect of a function that doesn't appear in its output. One example is modifying an array in place:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "o_YBuLQC1wPJ"
},
"outputs": [
{
"data": {
"text/plain": [
"array([123, 2, 3])"
]
},
"execution_count": 11,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"x = np.array([1, 2, 3])\n",
"\n",
"def in_place_modify(x):\n",
" x[0] = 123\n",
" return None\n",
"\n",
"in_place_modify(x)\n",
"x"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JTtUihVZ13F6"
},
"source": [
"The side-effectful function modifies its argument, but returns a completely unrelated value. The modification is a side-effect.\n",
"\n",
"Code like this runs fine on NumPy arrays. JAX arrays, however, won't allow themselves to be modified in-place:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "u6grTYIVcZ3f",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-709e2d7ddd3f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Raises error when we cast input to jnp.ndarray\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-11-fce65eb843c7>\u001b[0m in \u001b[0;36min_place_modify\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m123\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 5116\u001b[0m \u001b[0;34m\"immutable; perhaps you want jax.ops.index_update or \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5117\u001b[0m \"jax.ops.index_add instead?\")\n\u001b[0;32m-> 5118\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5120\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_operator_round\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndigits\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?"
]
}
],
"source": [
"in_place_modify(jnp.array(x)) # Raises error when we cast input to jnp.ndarray"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RGqVfYSpc49s"
},
"source": [
"Helpfully, the error points us to JAX's side-effect-free way of doing the same thing via the [`jax.ops.index_*`](https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-operators) ops. They are analogous to in-place modification by index, but create a new array with the corresponding modifications made:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "Rmklk6BB2xF0"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([123, 2, 3], dtype=int32)"
]
},
"execution_count": 13,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def jax_in_place_modify(x):\n",
" return jax.ops.index_update(x, 0, 123)\n",
"\n",
"y = jnp.array([1, 2, 3])\n",
"jax_in_place_modify(y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "91tn_25vdrNf"
},
"source": [
"Note that the old array was untouched, so there is no side-effect:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "KQGXig4Hde6T"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([1, 2, 3], dtype=int32)"
]
},
"execution_count": 14,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d5TibzPO25qa"
},
"source": [
"Side-effect-free code is sometimes called *functionally pure*, or just *pure*.\n",
"\n",
"Isn't the pure version less efficient? Strictly, yes; we are creating a new array. However, as we will explain in the next guide, JAX computations are often compiled before being run using another program transformation, `jax.jit`. If we don't use the old array after modifying it 'in place' using `jax.ops.index_update()`, the compiler can recognise that it can in fact compile the operation to an in-place modification, resulting in efficient code in the end.\n",
"\n",
"Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that.\n",
"\n",
"We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dFn_VBFFlGCz"
},
"source": [
"## Your first JAX training loop\n",
"\n",
"We still have much to learn about JAX, but you already know enough to understand how we can use JAX to build a simple training loop.\n",
"\n",
"To keep things simple, we'll start with a linear regression. \n",
"\n",
"Our data is sampled according to $y = w_{true} x + b_{true} + \\epsilon$."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "WGgyEWFqrPq1"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7fabbcb7c750>"
]
},
"execution_count": 15,
"metadata": {
"tags": []
},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAVT0lEQVR4nO3df6zddX3H8dfrHk6Xc/3BqeE60wt3ZZt2EQt0XhHXbE5kFp1CbZTphplzWTMzjUxWQoUI23Ql63SaaLY0kSyLREGtVzJ1FYLOaAbz1ttSoNQRI8KpxhK96OgVbm/f++PeA7en5/f3e358v+f5SJrc8+N+z+cEePXD+/v+fD6OCAEAsmts0AMAACRDkANAxhHkAJBxBDkAZBxBDgAZd8YgPvSss86K9evXD+KjASCz9u/f/3hETNQ+P5AgX79+vWZnZwfx0QCQWbYfqfc8pRUAyDiCHAAyjiAHgIwjyAEg4whyAMi4gXStAMComZmraPe+Izo6v6B15ZJ2bNmgrZsmU7k2QQ4APTYzV9HOvYe0sLgkSarML2jn3kOSlEqYU1oBgB7bve/IMyFetbC4pN37jqRyfYIcAHrs6PxCR893KpUgt122/XnbD9k+bPtVaVwXAPJgXbnU0fOdSmtG/nFJ/xkRvyXpAkmHU7ouAGTeji0bVCoWTnmuVCxox5YNqVw/8c1O22dK+j1J75SkiHha0tNJrwsAeVG9odmrrhUnPbPT9oWS9kh6UMuz8f2S3hcRT9a8b7uk7ZI0NTX18kceqbv3CwCgAdv7I2K69vk0SitnSPptSf8SEZskPSnputo3RcSeiJiOiOmJidN2YQQAdCmNIH9M0mMRce/K489rOdgBAH2QOMgj4seSHrVdrdq/VstlFgBAH6S1svO9km61vUbS9yX9WUrXBQC0kEqQR8QBSacV4AEAvcfKTgDIOIIcADKOIAeAjGMbWwBY0cs9w3uJIAcA9X7P8F6itAIA6v2e4b3EjBzASFhdNimPFxUhPbGw+EwJpdd7hvcSQQ4g92rLJj87vvjMa9USypmlouYXFk/73bT2DO8lSisAcq9e2WS1hcUl2erpnuG9RJADyL12yiPzxxe1a9tGTZZLsqTJckm7tm0c+hudEqUVACNgXbmkSoswX1cuaeumyUwEdy1m5AByr1V5JCsllEYIcgC5MTNX0eab79a5131Zm2++WzNzFUnLfeDlUrHu7xTszJRQGiHIAeRCtTOlMr+g0LPdKNUwv+ny8+rezPzIlRdkOsQlghxATrRa0LN102Rmb2a2ws1OALnQzoKerN7MbIUZOYBcKI/Xr4FnYUFPUqkFue2C7Tnb/5HWNQGgHTNzFf3fL0+c9nyx4Ex3o7QrzRn5+yQdTvF6ANCW3fuOaPFknPb8c9ackctSSq1UauS2z5b0h5I+LOn9aVwTwGjrZG/wRvXxJ+rsnZJHac3IPybpWkknU7oegBHWqpWwVqM6+CjUx6UUgtz2GyX9JCL2t3jfdtuztmePHTuW9GMB5Fine4Pv2LIhsxtepSGNGflmSZfb/oGkz0q6xPana98UEXsiYjoipicmJlL4WAB51ene4HnuEW9H4hp5ROyUtFOSbP++pL+JiKuSXhfA6Gq0yVWzUklee8TbQR85gKEz6qWSTqW6sjMiviHpG2leE8Doqc6ss3ii/SCwRB/AUBrlUkmnKK0AQMYR5ACQcQQ5AGQcQQ4AGUeQA0DG0bUCIBXVTa4q8wsq2FqK0CRtg31BkANIrLrJVXV/lKVY3lK2utmVJMK8hyitAEis3iZXVc02u0I6CHIAiTXazKrd15EMQQ4gsTNL9c/LrBqVfcEHhRo5gI6tPr3nzFJRv3jq9PMyV2Ozq94iyAG0ZXVXiiVVT8icb3GcWrlU5EZnjxHkAFqq7Uo5/Zjj+krFgm66/LzeDQySqJEDaEOzrpRGCvZIndIzSAQ5gJY67TopFQv6yJUXEOJ9QpADaGpmrqIxu+l7imPW2vHiSJ6XOQyokQM4TaMbm6tVn2cZ/uAlDnLb50j6d0m/quV/rnsi4uNJrwtgMNq5sVmwKZ0MkTRm5CckXRMR37X9PEn7bd8ZEQ+mcG0AfbC6L3xsZcOrZk5GEOJDJHGQR8SPJP1o5edf2D4saVISQQ5kQKMNr5phpeZwSbVGbnu9pE2S7q3z2nZJ2yVpamoqzY8F0IGZuYpuuuOBZxbyjFk62W5juJY7UlipOVxS61qx/VxJX5B0dUT8vPb1iNgTEdMRMT0xMZHWxwLowMxcRTs+d/CU1ZjthHi1Z4WOlOGUyozcdlHLIX5rROxN45oA0jUzV9E1tx9sq3SyGl0pwy+NrhVL+pSkwxHx0eRDApC2G2YO6dZ7ftj20vqqyXJJ377ukp6MCelJo7SyWdI7JF1i+8DKnzekcF0AKZiZq+jTXYQ4tfDsSKNr5Vt6toQGYEisXtTTjuesKag8vkZH5xe0jnJKprCyE8ihTkspxYL14TdzEzOrCHIgJ2rbCtvFzczsI8iBDOu0fLLaeHFMD/7963swKvQbQQ5kVO2KzE4Ux6x/2HZ+D0aFQSDIgYzq5rAHiVJKHhHkQMZ0W0656uIpfWjrxh6NCoNEkAMZUl1iv9jJ5iiSNv/GCwjxHCPIgSG1emvZal/3TXc80HGIMxPPP4IcGEK1NzIr8wu6+rYDHV2jVCywwdWIIMiBIdTtjcwqbmiOFoIcGDIzc5WOb2SuHS9q7oOv69GIMOxS248cQHLVkkonigXrxjed16MRIQuYkQND4oaZQ/r0PT9s+/2W2NwKkghyYKC67Qkvl4o6cCOlFCwjyIEB6XaJfXHMuulySil4FjVyYEA66Uwpl4qylrtRdr/1AkopOAUzcmBAjrZZTqGMglZSmZHbvsz2EdsP274ujWsCWTczV9Hmm+/Wudd9WZtvvlszc5VTXl9XLrW8xphEGQUtpXH4ckHSJyX9gaTHJH3H9h0R8WDSawNZVXtCT2V+QTv3HtLsIz/V1x86pqPzCzqzVFSxYC0u1V9yXyqOade28ymjoKU0SisXSXo4Ir4vSbY/K+kKSQQ5RtLMXKXuMWsLi0untBfOLyyqOGatHS9q/vgirYToWhpBPinp0VWPH5P0yto32d4uabskTU1NpfCxwHDave9I22dlLp4Mja85g1WZSKRvXSsRsScipiNiemJiol8fC/TNzFxFm/7uax33hLd70xNoJI0ZeUXSOasen73yHDAyZuYq2vH5gw3r3c20c9MTaCaNIP+OpBfbPlfLAf42SX+cwnWBoVW7V/iTT53oKsRLxYJ2bNnQgxFilCQO8og4Yfs9kvZJKki6JSIeSDwyYEjV2yu8G2vHi7rxTedxcxOJpbIgKCK+IukraVwLGHZJ9wqnrRBpY2Un0KFOb04y80avEeRAh9aVS22XUz72RxcS4Og5Ns0COrRjywaVioWW75sslwhx9AUzcqCJeifZV8O52WHIdKOgn5iRAw3MzFX0/tsOqDK/oNByd8r7bzugmbmKtm6a1GSD/
u+Czen16CuCHGhg5977dLLmuZMrz0v1SyylYkEfuZL9wtFflFaAGjfMHNJn7n1US1F/gc/C4nK8V8O6UekF6BeCHFil0wOQt26aJLgxcJRWgFU+c++jrd8EDBlm5BhZM3MV3XTHA5pfWJS0vHCnUTllteesad16CPQTQY6RNDNX0Y7PHdTiyWeD+2fHF1v+XmHM+vCbN/ZyaEDHCHKMpN37jpwS4u2Y5GYmhhRBjpHUaol9wdZShAq23v7Kc/ShrczCMbwIcoycmbmKLDU8jm2yXNK3r7ukn0MCEqFrBSOn2ZmaxYJZWo/MYUaO3Gq0T0qzbWh3v4VVmcgeghy5VO8Un517D0lqvA0tuxUiqxKVVmzvtv2Q7ftsf9F2Oa2BAd2Ymato88136+rbDpx2is/C4pJ27zvScI8USirIqqQ18jslvSwizpf0PUk7kw8J6E51Ft6sI6Uyv6Ctmya1a9tGTZZLspZn4uxWiCxLVFqJiK+teniPpLckGw7QvXbO0izYktgjBfmSZtfKuyR9tdGLtrfbnrU9e+zYsRQ/FljWzlma7SzBB7KmZZDbvsv2/XX+XLHqPddLOiHp1kbXiYg9ETEdEdMTExPpjB5YZV2Dgx5Wa3QYBJBlLUsrEXFps9dtv1PSGyW9NoLpDnqj2ZFrVTu2bDilU6UWNzSRV4lq5LYvk3StpFdHxPF0hgScql4r4dW3HdAH9t6nhcWTpwV7NfDL40VFSE8sLHLoA3ItaR/5JyT9iqQ7vXwT6Z6I+MvEowJWaXQT8/jKST2re8S5iYlRlLRr5TfTGghQzw0zzdsJq6o94oQ4RhErOzF0ag98aFc7XStAHhHkGCr1DnxoVztdK0Aesfshhko3Bz5IdKRgtDEjx1BptzxSLhVlS/PH6UgBCHIMlUY7E6521cVTnNgDrEJpBUNlx5YNKo657msWIQ7Uw4wcfdVqhWb159VdK2vHi7rxTedROgEaIMjRN80Oe6gNc0IbaB9Bjp6pnX0/+dSJuoc9XHP7QUkivIEuEeToiXqz70aWIurOzAG0h5ud6Il2DnlYrbrEHkDnCHL0RDfL5VliD3SHIEdPNFouv3a8+Mxxa+3+DoDmqJGja7WbW1XbBCXpyadOnPb+UrHwzOu1B0CwxB7oHkGOrtTb3Opnxxd1zecOakw6bb+Uer3grU78AdAeghxdabS51dLJUL1bnONrzqBXHOgRauToSqc3JrmRCfQOQY6ulIqd/avDjUygd1IJctvX2A7bZ6VxPQy3mbnKM+dltoMbmUBvJa6R2z5H0usk/TD5cJAFrRburB0vanzNGdzIBPokjZud/yzpWklfSuFayIBm9W5L7FQI9Fmi0ortKyRVIuJgG+/dbnvW9uyxY8eSfCwGrFm9+08uniLEgT5rGeS277J9f50/V0j6gKQPtvNBEbEnIqYjYnpiYiLpuDFAO7ZsUKlYOOU5Dn0ABqdlaSUiLq33vO2Nks6VdNDLS67PlvRd2xdFxI9THSWGSnXGzYIeYDh0XSOPiEOSXlh9bPsHkqYj4vEUxoUBanWKj8SCHmCYsLITp2j3FB8AwyO1II+I9WldC/1RnXlX5hdUsLUUoTFLtSvvq3uFE+TAcGJGPqJqZ95LsZzedbZPkcQSe2CYEeQ516je3ekJPiyxB4YXQZ5jzerdnc6wWWIPDC82zcqxerPuar27kxl2uVSkPg4MMYI8xxrNuo/OL9Rd1FNPqVjQTZefl/bQAKSIIM+xRrPudeWStm6a1K5tGzW58p7qOZprx4sql4qypMlySbu2bWQ2Dgw5auQ5tmPLhqZnY7KoB8gHgjwnmq3GZCk9kG8EeQ60Wo1JcAP5Ro08B5p1pwDIP4I8B5p1pwDIP4I8B5p1pwDIP4I8B+r1hHPgMTA6uNmZA3SnAKONIM8JulOA0UVpBQAyjiAHgIxLHOS232v7IdsP2P7HNAYFAGhfohq57ddIukLSBRHxlO0XtvodNNfOwccAsFrSm53vlnRzRDwlSRHxk+RDGl0cfAygG0lLKy+R9Lu277X9X7Zf0eiNtrfbnrU9e+zYsYQfm08stQfQjZYzctt3SXpRnZeuX/n9F0i6WNIrJN1u+9cj4rQjfCNij6Q9kjQ9Pd3giN/RxlJ7AN1oGeQRcWmj12y/W9LeleD+H9snJZ0liSl3F9aVS6rUCW2W2gNoJmlpZUbSayTJ9kskrZH0eNJBjSqW2gPoRtKbnbdIusX2/ZKelvSn9coqo6yTLhSW2gPohgeRu9PT0zE7O9v3z+232i4UaXmGzTmYALphe39ETNc+z14rKVs9Ax+ztVTzF+XC4pKuuf2gJFoKAaSDIE9R7Qy8NsSrliLoDweQGvZaSVG9PvBG6A8HkBZm5F2qdxOz035v+sMBpIEg70KjpfRnloqaX1hs+zr0hwNIA6WVLjRaSm+rbh/4VRdP0R8OoGcI8i40KonMH1/Urm0bNVkuyZImyyXt2rZRH9q6se7z3OgEkAZKK11otpS+0ZFrHMUGoFeYkXeBpfQAhgkz8i6wlB7AMCHIu0SpBMCwoLQCABlHkANAxhHkAJBxBDkAZBxBDgAZR5ADQMYlCnLbF9q+x/YB27O2L0prYACA9iSdkf+jpL+NiAslfXDlMQCgj5IGeUh6/srPZ0o6mvB6AIAOJV3ZebWkfbb/Sct/KfxOozfa3i5puyRNTU0l/FgAQFXLILd9l6QX1XnpekmvlfTXEfEF21dK+pSkS+tdJyL2SNojSdPT0/UPs0yg3ok9LKEHMAocDQ4IbuuX7ScklSMibFvSExHx/Fa/Nz09HbOzs11/bq3aE3uk5d0I2fMbQJ7Y3h8R07XPJ62RH5X06pWfL5H0vwmv15VGJ/ZwuDGAUZC0Rv4Xkj5u+wxJv9RKDbzfGp3Yw+HGAEZBoiCPiG9JenlKY+lasxN7ACDvcrGykxN7AIyyzBws0awrhRN7AIyyTAR5bVdKZX5BO/cekqRTwpzgBjCKMlFaoSsFABrLRJDTlQIAjWUiyBt1n9CVAgAZCXK6UgCgsUzc7KQrBQAay0SQS3SlAEAjmSitAAAaI8gBIOMIcgDIOIIcADKOIAeAjEt0QlDXH2ofk/RI3z+4d86S9PigB9FHfN/8G7XvnJXv+2sRMVH75ECCPG9sz9Y7fimv+L75N2rfOevfl9IKAGQcQQ4AGUeQp2PPoAfQZ3zf/Bu175zp70uNHAAyjhk5AGQcQQ4AGUeQp8D2btsP2b7P9hdtlwc9pl6z/VbbD9g+aTuzbVut2L7M9hHbD9u+btDj6SXbt9j+ie37Bz2WfrB9ju2v235w5d/l9w16TN0iyNNxp6SXRcT5kr4naeeAx9MP90vaJumbgx5Ir9guSPqkpNdLeqmkt9t+6WBH1VP/JumyQQ+ij05IuiYiXirpYkl/ldV/vgR5CiLiaxFxYuXhPZLOHuR4+iEiDkdE3k+/vkjSwxHx/Yh4WtJnJV0x4DH1TER8U9JPBz2OfomIH0XEd1d+/oWkw5IyeegBQZ6+d0n66qAHgVRMSnp01ePHlNH/0NGc7fWSNkm6d7Aj
6U5mTggaNNt3SXpRnZeuj4gvrbznei3/79qt/Rxbr7TznYGss/1cSV+QdHVE/HzQ4+kGQd6miLi02eu23ynpjZJeGzlpzm/1nUdARdI5qx6fvfIccsJ2UcshfmtE7B30eLpFaSUFti+TdK2kyyPi+KDHg9R8R9KLbZ9re42kt0m6Y8BjQkpsW9KnJB2OiI8OejxJEOTp+ISk50m60/YB2/866AH1mu03235M0qskfdn2vkGPKW0rN7DfI2mflm+E3R4RDwx2VL1j+zOS/lvSBtuP2f7zQY+pxzZLeoekS1b+uz1g+w2DHlQ3WKIPABnHjBwAMo4gB4CMI8gBIOMIcgDIOIIcADKOIAeAjCPIASDj/h/USuotBmiqlQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"xs = np.random.normal(size=(100,))\n",
"noise = np.random.normal(scale=0.1, size=(100,))\n",
"ys = xs * 3 - 1 + noise\n",
"\n",
"plt.scatter(xs, ys)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RTh22mo4rR1x"
},
"source": [
"Therefore, our model is $\\hat y(x; \\theta) = wx + b$.\n",
"\n",
"We will use a single array, `theta = [w, b]`, to house both parameters:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "TnVrRTMamyzb"
},
"outputs": [],
"source": [
"def model(theta, x):\n",
" \"\"\"Computes wx + b on a batch of input x.\"\"\"\n",
" w, b = theta\n",
" return w * x + b"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qCrLmmKrn9_h"
},
"source": [
"The loss function is the mean squared error, $J(x, y; \\theta) = \\frac{1}{N} \\sum_i (\\hat y_i - y_i)^2$."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "07eMcDLMn9Ww"
},
"outputs": [],
"source": [
"def loss_fn(theta, x, y):\n",
" prediction = model(theta, x)\n",
" return jnp.mean((prediction-y)**2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ejMt4dulnoYX"
},
"source": [
"How do we optimize a loss function? Using gradient descent. At each update step, we will find the gradient of the loss w.r.t. the parameters, and take a small step in the direction of steepest descent:\n",
"\n",
"$\\theta_{new} = \\theta - 0.1 (\\nabla_\\theta J) (x, y; \\theta)$"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "2I6T5Wphpaaa"
},
"outputs": [],
"source": [
"def update(theta, x, y, lr=0.1):\n",
" return theta - lr * jax.grad(loss_fn)(theta, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MAUL1gT_opVn"
},
"source": [
"In JAX, it's common to define an `update()` function that is called every step, taking the current parameters as input and returning the new parameters. This is a natural consequence of JAX's functional nature, and is explained in more detail in [The Problem of State](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/07-state.ipynb).\n",
"\n",
"This function can then be JIT-compiled in its entirety for maximum efficiency. The next guide will explain exactly how `jax.jit` works, but if you want to, you can try adding `@jax.jit` before the `update()` definition, and see how the training loop below runs much faster."
]
},
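{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance (a minimal sketch; `jax.jit` is covered properly in the next guide), you can also wrap the function after the fact, without touching its definition:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# JIT-compile the update step. The first call triggers compilation;\n",
"# subsequent calls reuse the compiled code.\n",
"jit_update = jax.jit(update)\n",
"\n",
"print(jit_update(jnp.array([1., 1.]), xs, ys))"
]
},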
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "WLZxY7nIpuVW"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"w: 3.00, b: -1.00\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAZRklEQVR4nO3de5BcZZnH8d8znQ7pAcIECWAGxiBiUDYm0RaiKMpFAgmEGBVkxfWeKm9L1A0kAgoLbLJmxeBqlZtV13JhJUqyIxRiSLwulAEmzJBwC6IESAclaCZgMknm8uwfc2Gmp3v6drpPn+7vp8oi5/Tpc94h8sub97zv+5i7CwAQXQ1hNwAAUBqCHAAijiAHgIgjyAEg4ghyAIi4cWE89KijjvKpU6eG8WgAiKzNmze/6O6T08+HEuRTp05VW1tbGI8GgMgys2cynWdoBQAijiAHgIgjyAEg4ghyAIg4ghwAIo4gB4CIC2X6IQDUm9b2lFau36adnV2a0pTQkjnTtGBWcyD3JsgBoMxa21O6cu0WHejpkySlOru0bN1WSQokzBlaAYAycnctXfdKiA/q6u7VyvXbAnlGIEFuZk1mdruZPWFmj5vZ24K4LwBE2dYde3TCsp9pf3dfxs93dnYF8pyghlZulvRzd3+/mY2X1BjQfQEgcvr6XB/4j99p8zO7JUkNJvVlKMY2pSkRyPNKDnIzO0LSGZI+KknuflDSwVLvCwBRdN9TL+pD371/6PgHH3urOvd1a9m6rerq7h06n4jHtGTOtECeGUSP/ARJuyT9l5nNkLRZ0uXuvjeAewNAJHT39undK3+t1MBwySlTJuqOz71DsQYbuqZcs1as1OLLZpaUtEnS6e5+v5ndLOkld78m7bpFkhZJUktLy1ueeSbjJl4AEDl3bXlen/2fh4aO133m7Xpzy6TAn2Nmm909mX4+iB75Dkk73H3w7xK3S1qafpG7r5a0WpKSyWRpf3oAQBXYd7BHM667R929/ZF21slH63sfScrMcnwzWCUHubv/ycyeM7Np7r5N0tmSHiu9aQBQvf570zO6pvWRoeMNXzhDJx1zeChtCWrWyucl3TowY+WPkj4W0H0BoKrs3ntQs67fMHR86aktWr5weogtCijI3b1D0qhxGwCoJas2PqlVG38/dHzf0rPUHNAUwlKwRB8ActjZ2aW3r/jl0PE/nn2Svvie14fYopEIcgAYw9WtW3XLpmeHjh+65j068tDxIbZoNIIcADJ46oWXdc5Nvx06vm7+KfrI26eG16AxEOQAMKC1PaWv/fwJ7dyzf+hcg0lbr52jQw+p3ris3pYBQAW1tqe05CcPq3vYpijxmGnl+2dUdYhLbGMLAOrp7dPiNR0jQlySuns9sK1my6m6/5gBgIAMr9DT1BiXu7Snq1uHHjJOfzvQk/V7QW01W04EOYCa19qeGrH74O593UOfjRXiUnBbzZYTQysAat7K9dtGbCGbyaTGuBLx2IhzQW41W04EOYCal8/wSOe+bi1fOF3NTQmZpOamhJYvnB7YVrPlxNAKgJo3pSkxtE/4WNcsmNUcieBOR48cQE370QPP5gzxqAyhZEOPHEDNGD4zJZ9euCTFzCIzhJINPXIANWFwZkqqs0sujQrxVZfMzPgy8+sXz4h0iEv0yAHUiGwzUyY1xtX+lXNHXFeOuplhIsgB1IRswyidw+aMR/VlZi4MrQCItN4+19Sld2X9PAoLekoVWI/czGKS2iSl3P2CoO4LANmMFeBS/6ZXUZ6Nkq8ge+SXS3o8wPsBQEa79x7MGeKSdOj4cTU5lJIukB65mR0naZ6kGyV9MYh7Aqhv6VMJB19MZgpwk+Sjb6E9Xd0ZztaeoHrkqyRdIakvoPsBqGOZphJeuXbLqBB/4vrztH3FvKzj4PUwPi4FEORmdoGkF9x9c47rFplZm5m17dq1q9THAqhhmaYSHugZ2U/cvmKeJgzMC18yZ1pkN7wKQhBDK6dLmm9mcyVNkDTRzG5x98uGX+TuqyWtlqRkMpnpb0EAIGnsTa6eXj5XZjbi3OA4eC3OEc9HyUHu7sskLZMkM3u3pH9KD3EAKES25fXNTYlRIT6oVueI54N55ACqytK1WzKGeD0NlRQq0JWd7v5rSb8O8p4A6kf6y8wJ8QYd6O6ru6GSQrFEH0DoTr1xo154+cCIc9tXzAupNdFDkAMIVXov/JoL3qhPvOOEkFoTTQQ5gFBkWthDL7w4BDmAiuru7dNJV9094tyaRbN12mtfFVKLoo8gB1Ax9MLLgyAHUHYvvLxfp974ixHnNi07W8ceMSGkFtUWghxAIAY3uUp1dilmpl53NWdZ2EMvPFgEOYCSDW5yNbg/Sq/378KRHuJP3nC+xo9jHWLQCHIAJctWL3M4euHlwx+NAEo21iZXUv9+4SgfghxAyY5IxMf8vF72BQ8LQysACja8ek8+e1Kz2VV5EeQA8jJ8Vkq20mqZNCXibHZVZgQ5gJzSZ6XkG+KJeEzXzj+lfA2DJMbIAeQhn1kp6WJmWr5wOr3xCiDIAeSUa1ZKukQ8pq9fPIMQrxCCHMCYbnvg2ZxDKfEG06TGuEz95djoiVcWY+QARhn+YjObwReezVTvCV3JQW5mx0v6oaRj1P/7utrdby71vgDC0dqe0hW3b9HB3r6s18TMGDqpIkH0yHskfcndHzKzwyVtNrMN7v5YAPcGUAGFzgvvcyfEq0jJQe7uz0t6fuDXL5vZ45KaJRHkQBUaHtpTmhI68+TJWrs5VdCsFFZqVpdAx8jNbKqkWZLuz/DZIkmLJKmlpSXIxwLIU/p88FRnl27Z9GxB90jEY6zUrDKBzVoxs8MkrZW02N1fSv/c3Ve7e9Ldk5MnTw7qsQAKUMx8cOmVTa+YkVKdAumRm1lc/SF+q7uvC+KeAILT2p7SdXc+qt37ugv+LrNSql8Qs1ZM0vckPe7uN5XeJABBam1PacntD6u7N9+F9a9obkrovqVnlaFVCFIQQyunS/qwpLPMrGPgf3MDuC+AAKxcv62oEGcsPDqCmLVyr9g3Hqg6+SzqGe7Q8TE1NY4fms3CcEp0sLITqEFXt27VrZtyL60fFI+ZbnwvLzGjiiAHagAvM+sbQQ5EWGt7Skt+0qHu7Kvps5rUGFf7V84NvlGoOIIciKjW9pS+uKZDRWS4EvGYvnohBR9qBUEORNTK9duKCnGGUmoPQQ5EzNWtW/Wj+59Trxc2pfCy2S26YcH0MrUKYSLIgSqVvrnVkjnT1PbMXwveG0WSTj/xSEK8hhHkQBXKtLnV4jUdBd/HJH2InnjNI8iBKlTs5laDmhJxXTv/FMbB6wRBDlShQosdx8z0h+XsjFGvKL4MVJnW9pQarLBdLy497fgytQZRQI8cqCKDY+P5zkiJmenS045nDLzOEeRAiIbPTEnEG7SvgCWa21fMK2PLECUEORCS9JkphYR4MzUzMQxj5EBIip2Zwj7hSEeQAyEpZGZKUyIuEzUzkRlDK0BIpjQl8ir60JSIq+Or7FKI7IIqv
nyepJslxSR9191XBHFfIMoyLbEf7El/6ccP5xXiDZKunc8uhRibeYEb74y6gVlM0pOS3iNph6QHJV3q7o9l+04ymfS2traSngtUs0Ir9JikRLxBXd19Q99JxBu0fOGbGEbBEDPb7O7J9PNB9MhPlfSUu/9x4EG3SbpIUtYgB2pZa3sq7xCfMK5BK95HWKM0QbzsbJb03LDjHQPnRjCzRWbWZmZtu3btCuCxQHVauX5b3j3x/T19Wrl+W1nbg9pXsVkr7r7a3ZPunpw8eXKlHgtUTGt7SrP++Z68q9YPKnRfFSBdEEMrKUnDN3o4buAcUDda21NacvvD6u4t/J3TFBb3oERBBPmDkk4ysxPUH+AflPT3AdwXqFrpM1J27z1QVIizuAdBKDnI3b3HzD4nab36px9+390fLbllQJXKVPShGNTORFACmUfu7j+T9LMg7gVUu2KW1pskF+GN8mBlJ1CgQl9OTmqM66sXUq0H5UOQAwXKd2m9SfrGJTMJcJQdm2YBBVowa0rOaxLxGCGOiqFHDowhfXZKvi822aEQlUSQA1m0tqe0eE3H0HF6iB87cYL+9NL+Ud9rbkoQ4qgoghxIM9gLH6v3vX3FvFHTECXmhSMcBDkwTCErNAd73dm2qgUqhSAHhrnuzkcLWqG5YFYzwY3QMWsFGGb3vu6wmwAUjB456lZre0rX3vGoOrsKC+/TTzyyTC0CikOQoy61tqe05CcPq7uvsI2uTj/xSN36qbeVqVVAcQhy1KVr73g07xCPN5hWfmAGY+GoWoyRo+60tqdyDqc0NyVkA/8kxFHt6JGjruzZ1z1ikU8mzU0J3bf0rAq1CCgdQY6aVczy+njMWNCDyCHIUZOKKf7QYNLK9zOMguhhjBw1pbU9pdNX/FKL13RkLf6w6pKZSsRjI84l4jHddDG7FSKaSuqRm9lKSRdKOijpD5I+5u6dQTQMKFSmvU8yYWk9ak2pQysbJC0bqNv5r5KWSbqy9GYBhcunBFvMTBJL61FbShpacfd73L1n4HCTpONKbxJQnHzGwXu98Er3QLULcoz845LuzvahmS0yszYza9u1a1eAjwWkqUvvyuu65qZEmVsCVF7OoRUz2yjp2AwfXeXuPx245ipJPZJuzXYfd18tabUkJZNJukUIRL4BLrFXOGpXziB393PG+tzMPirpAklnu/P3VpRHa3tK19356NDuhIl4g7q6+0Zdt+qSmUMvMZsa43KX9nR180ITNa3UWSvnSbpC0rvcfV8wTQJGylTsIT3EE/HYUJ1Mwhr1ptQx8m9JOlzSBjPrMLPvBNAmYMjVrVu1eE1HzmIPXd29Wrl+W4VaBVSXknrk7v66oBoCpLu6datu2fRs3tfvzLPCPVBrWKKPqlNswYcpzEhBnSLIUVWKLfjAjBTUM4IcVWXl+m15hXgi3qAJ8Zg69zEjBSDIUVXyWZ152ewW3bBgegVaA0QDQY6qsOiHbbrnsT/nvI4QB0YjyFFR6cUelsyZlrNij9Rf8IG9woHM2I8cFTO4zWyqs0uu/mGU9BDfvmKeVl0yU02J+NC5SY1xQhwYAz1ylE1673vvgZ4xt5ndvmKeJLaYBQpFjxxlkan3Pda88EQ8ptb2VOUaCNQQghxlkU+Rh+FYYg8UjyBHWRSzXJ4l9kBxCHKURbbl8g32Srm1fL8DYGwEOQLX2p7KuLBnsFL91y+ekbGKPUvsgeIwawVFS9/calJjXGeePFnrHto56tpJjXF99cJTRsxGoYo9EAyCHEXJtLnV7n3dGUNckhrHjxsR1EwxBILD0AqKku/mVoN4kQmUD0GOouSzudVwvMgEyieQIDezL5mZm9lRQdwP1a3QhTu8yATKq+QgN7PjJZ0rKf+aXIi0XJtcNcYb1NyUkElqbkoMFUUGUB5BvOz8hqQrJP00gHuhip1/8//p8edfynndvyx8E8ENVFBJQW5mF0lKufvDlmWRx7BrF0laJEktLS2lPBYhmLr0rryuu2x2CyEOVFjOIDezjZKOzfDRVZK+rP5hlZzcfbWk1ZKUTCYLK8iI0GQK8FWXzNSydVtH7KVikj5E0QcgFDmD3N3PyXTezKZLOkHSYG/8OEkPmdmp7v6nQFuJUKSH+BtePVF3X/7OoWMW9ADVoeihFXffKunowWMz2y4p6e4vBtAuhChTL3xwr/BBLOgBqgfzyDGkt88zhjh7hQPVLbAl+u4+Nah7oTIGK/jkWtwzuFc4PXCgOrHXSp0arOCTb/EHltgD1Ysgr3GZqtYvmNVccAUfltgD1Ysgr2Hpve5UZ5eWrduqp1/cW/BeKSyxB6oXLztrWKZed1d3r27+xe8Luk9TIs74OFDFCPIalmtce8K43L/9iXhM184/JagmASgDgryGjTWuvX3FPK1435vUPHDNYB3NSY1xNSXibHgFRAhj5DXs6MMPGTUWnojHtHxh/zJ6FvUAtYEgrxHps1MyvcxsZik9UJMI8hqQaXbKcOnL6wHUFsbIa0C2OeGHHTKOEAfqAEFeA7LNCd97oKfCLQEQBoI8wnp6+8Ys+MBqTKA+MEYeUbkq9lDwGKgf9Mgj5i9/OzAqxO+98kytumQmBY+BOkWPPELGKvhw3KRGghuoUwR5BDz8XKcu+vZ9I849cf15mhCPhdQiANWEIK9y+ZRdA1DfSg5yM/u8pM9K6pV0l7tfUXKroDUPPqsr124dce7p5XM1UOgaAIaUFORmdqakiyTNcPcDZnZ0ru9gbK3tKS1e0zHiXDxm+v2Nc0NqEYBqV2qP/NOSVrj7AUly9xdKb1L9+vgPHtQvnxj5r3D4JlcAkEmp0w9fL+mdZna/mf3GzN6a7UIzW2RmbWbWtmvXrhIfW3umLr1rVIhLrxQ+BoBscvbIzWyjpGMzfHTVwPePlDRb0lsl/djMXuvunn6xu6+WtFqSksnkqM/r1fxv3astO/aMeQ2FjwGMJWeQu/s52T4zs09LWjcQ3A+YWZ+koyTR5c5D+oyUSY1x7d7XPeo6ltoDGEupY+Stks6U9Csze72k8ZJeLLlVNS7blML07WglltoDyK3UIP++pO+b2SOSDkr6SKZhlXo2vODDq4+YoJ179o/4/M7PvUPTjztCkoZWZg4vEEEhCAC5WBi5m0wmva2treLPrbRMPezhWNgDoBBmttndk+nnWdkZsOE98AYz9Wb4g7JB0k2XzKx84wDUJII8QOk98EwhLkl9kpat61+1ybAJgFKxjW2AspVcy4T54QCCQo+8SOlV65fMmVbwfG/mhwMIAkFehExV669cu0WFvjZmfjiAIDC0UoRMQygHevoyXpuIx3TZ7BYl0vYOZ344gKAQ5EUYa0gkU8m1GxZM1/KF0ynFBqAsGFopwpSmhFIZwry5KaEFs5ozBnS28wBQKnrkBfrVthcyhjhDJQDCQo88T+6uE5b9bMS5YydO0J9f2s9SegChIsjzkF527R2vO0q3fPK0EFsEAK8gyMfQ2+c68csje+Fbrj1XEyfEQ2oRAIxGkGdx0z3b9M1fPjV0fNnsFt2wgJJrAKoPQZ5mf3evTr7m5yPOPXnD
+Ro/jvfCAKoTQT7M5be166cdO4eOvzz3ZC0648QQWwQAuRHkkv6696DefP2GEeeeXj5XZhZSiwAgf3Uf5Bf++73amnql+PE3L52l+TOmhNgiAChMSUFuZjMlfUfSBEk9kj7j7g8E0bBye+Yve/Wulb8ecY6KPQCiqNQe+dckXefud5vZ3IHjd5fcqjJ7/dV36+CwTa5uWzRbs1/7qhBbBADFKzXIXdLEgV8fIWnnGNeGruO5Ti349n0jztELBxB1pQb5Yknrzezf1L9vy9uzXWhmiyQtkqSWlpYSH1u4qUvvGnG84Qtn6KRjDq94OwAgaDmD3Mw2Sjo2w0dXSTpb0hfcfa2ZXSzpe5LOyXQfd18tabUkJZPJQmsw5JSpYs+CWc3a+Nif9ckftg1d95pXNeo3S84M+vEAEBrzLAWC8/qy2R5JTe7u1j9Xb4+7T8z1vWQy6W1tbbkuy1t6xR5JmjCuQfvTij3c/+WzdczECYE9FwAqycw2u3sy/XypyxV3SnrXwK/PkvT7Eu9XlEwVe4aH+FknH63tK+YR4gBqUqlj5J+SdLOZjZO0XwNj4JU2VsWeR66bo8MOqfvp8gBqWEkJ5+73SnpLQG0p2lgVewhxALUu8jtBHezp094DPaPOU7EHQL2ITHc106yUWIPp8z9qH7pm8mGH6MW/HaBiD4C6EokgT5+Vkurs0uI1HUOfn/OGo/Wf/5BkkysAdSkSQZ5pVsqgjV88Q687moU9AOpXJMbIs81KMYkQB1D3IhHkU5oSBZ0HgHoSiSBfMmeaEvHYiHPMSgGAfpEYIx+cfZJpLxUAqHeRCHKpP8wJbgAYLRJDKwCA7AhyAIg4ghwAIo4gB4CII8gBIOIIcgCIuJJKvRX9ULNdkp6p+IPL5yhJL4bdiAri56199fYzR+XnfY27T04/GUqQ1xoza8tUR69W8fPWvnr7maP+8zK0AgARR5ADQMQR5MFYHXYDKoyft/bV288c6Z+XMXIAiDh65AAQcQQ5AEQcQR4AM1tpZk+Y2RYz+18zawq7TeVmZh8ws0fNrM/MIjttKxczO8/MtpnZU2a2NOz2lJOZfd/MXjCzR8JuSyWY2fFm9isze2zg/8uXh92mYhHkwdgg6e/c/U2SnpS0LOT2VMIjkhZK+m3YDSkXM4tJ+rak8yW9UdKlZvbGcFtVVj+QdF7YjaigHklfcvc3Spot6bNR/f0lyAPg7ve4e8/A4SZJx4XZnkpw98fdfVvY7SizUyU95e5/dPeDkm6TdFHIbSobd/+tpL+G3Y5Kcffn3f2hgV+/LOlxSZGsXkOQB+/jku4OuxEIRLOk54Yd71BE/0PH2MxsqqRZku4PtyXFiUypt7CZ2UZJx2b46Cp3/+nANVep/69rt1aybeWSz88MRJ2ZHSZpraTF7v5S2O0pBkGeJ3c/Z6zPzeyjki6QdLbXyOT8XD9zHUhJOn7Y8XED51AjzCyu/hC/1d3Xhd2eYjG0EgAzO0/SFZLmu/u+sNuDwDwo6SQzO8HMxkv6oKQ7Qm4TAmJmJul7kh5395vCbk8pCPJgfEvS4ZI2mFmHmX0n7AaVm5m918x2SHqbpLvMbH3YbQrawAvsz0lar/4XYT9290fDbVX5mNmPJP1O0jQz22Fmnwi7TWV2uqQPSzpr4L/bDjObG3ajisESfQCIOHrkABBxBDkARBxBDgARR5ADQMQR5AAQcQQ5AEQcQQ4AEff/wA5ga+Fcz+UAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"theta = jnp.array([1., 1.])\n",
"\n",
"for _ in range(1000):\n",
" theta = update(theta, xs, ys)\n",
"\n",
"plt.scatter(xs, ys)\n",
"plt.plot(xs, model(theta, xs))\n",
"\n",
"w, b = theta\n",
"print(f\"w: {w:<.2f}, b: {b:<.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5-q17kJ_rjLc"
},
"source": [
"As you will see going through these guides, this basic recipe underlies almost all training loops you'll see implemented in JAX. The main difference between this example and real training loops is the simplicity of our model: that allows us to use a single array to house all our parameters. We cover managing more parameters in the later [pytree guide](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/05.1-pytrees.ipynb). Feel free to skip forward to that guide now to see how to manually define and train a simple MLP in JAX."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Jax Basics.ipynb",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}