{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "kORMl5KmfByI"
},
"source": [
"# Advanced Automatic Differentiation in JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/jax-101/04-advanced-autodiff.ipynb)\n",
"\n",
"*Authors: Vladimir Mikulik & Matteo Hessel*\n",
"\n",
"Computing gradients is a critical part of modern machine learning methods. This section considers a few advanced topics in automatic differentiation as it relates to modern machine learning.\n",
"\n",
"While understanding how automatic differentiation works under the hood isn't crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.\n",
"\n",
"[The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. You don't need it for most things in JAX, but some features (like defining [custom derivatives](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)) build on it, so it's worth knowing it exists in case you ever need them."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qx50CO1IorCc"
},
"source": [
"## Higher-order derivatives\n",
"\n",
"JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.\n",
"\n",
"We illustrate this in the single-variable case:\n",
"\n",
"The derivative of $f(x) = x^3 + 2x^2 - 3x + 1$ can be computed as:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "Kqsbj98UTVdi"
},
"outputs": [],
"source": [
"import jax\n",
"\n",
"f = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
"\n",
"dfdx = jax.grad(f)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ItEt15OGiiAF"
},
"source": [
"The higher-order derivatives of $f$ are:\n",
"\n",
"$$\n",
"\\begin{array}{l}\n",
"f'(x) = 3x^2 + 4x -3\\\\\n",
"f''(x) = 6x + 4\\\\\n",
"f'''(x) = 6\\\\\n",
"f^{iv}(x) = 0\n",
"\\end{array}\n",
"$$\n",
"\n",
"Computing any of these in JAX is as easy as chaining the `grad` function:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "5X3yQqLgimqH"
},
"outputs": [],
"source": [
"d2fdx = jax.grad(dfdx)\n",
"d3fdx = jax.grad(d2fdx)\n",
"d4fdx = jax.grad(d3fdx)"
]
},
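{
"cell_type": "markdown",
"metadata": {},
"source": [
"The chaining can also be done programmatically. As a small sketch, the helper `nth_derivative` below (a hypothetical convenience function, not part of the JAX API) simply wraps a scalar function in `jax.grad` `n` times:\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"def nth_derivative(fn, n):\n",
"  # Hypothetical helper: apply jax.grad to fn n times.\n",
"  for _ in range(n):\n",
"    fn = jax.grad(fn)\n",
"  return fn\n",
"\n",
"f = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
"\n",
"# f''(x) = 6x + 4, so this should print 10.0.\n",
"print(nth_derivative(f, 2)(1.))\n",
"```\n",
"\n",
"Note that `jax.grad` requires a scalar output, so this pattern only works for scalar-to-scalar functions: for a function of a vector, `jax.grad(jax.grad(f))` fails because `jax.grad(f)` returns a vector. That case calls for Jacobians, discussed next."
]
},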
{
"cell_type": "markdown",
"metadata": {
"id": "fVL2P_pcj8T1"
},
"source": [
"Evaluating these at $x=1$ gives:\n",
"\n",
"$$\n",
"\\begin{array}{l}\n",
"f'(1) = 4\\\\\n",
"f''(1) = 10\\\\\n",
"f'''(1) = 6\\\\\n",
"f^{iv}(1) = 0\n",
"\\end{array}\n",
"$$\n",
"\n",
"Using JAX:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "tJkIp9wFjxL3",
"outputId": "581ecf87-2d20-4c83-9443-5befc1baf51d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.0\n",
"10.0\n",
"6.0\n",
"0.0\n"
]
}
],
"source": [
"print(dfdx(1.))\n",
"print(d2fdx(1.))\n",
"print(d3fdx(1.))\n",
"print(d4fdx(1.))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3-fTelU7LHRr"
},
"source": [
"In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to\n",
"\n",
"$$(\\mathbf{H}f)_{i,j} = \\frac{\\partial^2 f}{\\partial x_i\\,\\partial x_j}.$$\n",
"\n",
"The Hessian of a real-valued function of several variables, $f: \\mathbb R^n\\to\\mathbb R$, can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, `jax.jacfwd` and `jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances; see the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY) linked above for an explanation."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "ILhkef1rOB6_"
},
"outputs": [],
"source": [
"def hessian(f):\n",
" return jax.jacfwd(jax.grad(f))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xaENwADXOGf_"
},
"source": [
"Let's double-check this on the dot product $f: \\mathbf{x} \\mapsto \\mathbf{x}^\\top \\mathbf{x}$.\n",
"\n",
"If $i=j$, $\\frac{\\partial^2 f}{\\partial x_i\\,\\partial x_j}(\\mathbf{x}) = 2$. Otherwise, $\\frac{\\partial^2 f}{\\partial x_i\\,\\partial x_j}(\\mathbf{x}) = 0$. So the Hessian should be $2I$ for every $\\mathbf{x}$."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Xm3A0QdWRdJl",
"outputId": "e1e8cba9-b567-439b-b8fc-34b21497e67f"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[2., 0., 0.],\n",
" [0., 2., 0.],\n",
" [0., 0., 2.]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"import jax.numpy as jnp\n",
"\n",
"def f(x):\n",
" return jnp.dot(x, x)\n",
"\n",
"hessian(f)(jnp.array([1., 2., 3.]))"
]
},
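{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `jax.jacfwd` and `jax.jacrev` compute the same Jacobian, either one can sit on the outside. JAX also provides `jax.hessian`, which composes forward mode over reverse mode. The sketch below checks that all three agree on a function whose Hessian depends on the input; the sum-of-cubes test function is an arbitrary choice, with Hessian $\\mathrm{diag}(6\\mathbf{x})$.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def g(x):\n",
"  # Sum of cubes: the Hessian is diagonal with entries 6 * x_i.\n",
"  return jnp.sum(x ** 3)\n",
"\n",
"x = jnp.array([1., 2., 3.])\n",
"\n",
"h_fwd = jax.jacfwd(jax.grad(g))(x)  # forward-over-reverse\n",
"h_rev = jax.jacrev(jax.grad(g))(x)  # reverse-over-reverse\n",
"h_lib = jax.hessian(g)(x)           # library helper\n",
"\n",
"print(h_fwd)  # diagonal entries 6, 12, 18\n",
"```\n",
"\n",
"The Autodiff Cookbook suggests that forward-over-reverse is typically the most efficient of these combinations, which is the composition `jax.hessian` uses."
]
},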
{
"cell_type": "markdown",
"metadata": {
"id": "7_gbi34WSUsD"
},
"source": [
"Often, however, we aren't interested in computing the full Hessian itself, and doing so can be very inefficient. [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) explains some tricks, like the Hessian-vector product, that let you work with the Hessian without materialising the whole matrix.\n",
"\n",
"If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook."
]
},
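{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the Hessian-vector product idea, one can push a tangent vector through `jax.grad(g)` with `jax.jvp` (forward-over-reverse), which never builds the full $n \\times n$ matrix. The name `hvp` and the test function are illustrative choices; the comparison against `jax.hessian` is only sensible for tiny inputs.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def hvp(g, x, v):\n",
"  # The JVP of grad(g) at x in direction v equals H(x) @ v.\n",
"  return jax.jvp(jax.grad(g), (x,), (v,))[1]\n",
"\n",
"def g(x):\n",
"  return jnp.sum(jnp.sin(x) * x)\n",
"\n",
"x = jnp.array([0.5, 1.0, -2.0])\n",
"v = jnp.array([1.0, 0.0, 2.0])\n",
"\n",
"print(hvp(g, x, v))\n",
"# Cross-check against the explicitly materialised Hessian.\n",
"print(jax.hessian(g)(x) @ v)\n",
"```\n",
"\n",
"Here the Hessian happens to be diagonal, with entries $2\\cos x_i - x_i \\sin x_i$, which gives a third way to check the result."
]
},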
{
"cell_type": "markdown",
"metadata": {
"id": "zMT2qAi-SvcK"
},
"source": [
"## Higher-order optimization\n",
"\n",
"Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier:\n",
"\n",
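"```python\n",
"import jax\n",
"\n",
"# Sketch: assumes loss_fn(params, data), a learning rate lr, array-valued\n",
"# params, and a batch of data are defined elsewhere.\n",
"def meta_loss_fn(params, data):\n",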
" \"\"\"Computes the loss after one step of SGD.\"\"\"\n",
" grads = jax.grad(loss_fn)(params, data)\n",
" return loss_fn(params - lr * grads, data)\n",
"\n",
"meta_grads = jax.grad(meta_loss_fn)(params, data)\n",
"```"
]
},
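{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see this run end to end, here is a self-contained sketch with a toy quadratic `loss_fn`, an illustrative step size `lr = 0.1`, and made-up arrays for `params` and `data` (all assumptions, not part of MAML itself). It uses `jax.tree_util.tree_map` for the inner update so that `params` could also be a pytree. For this loss the meta-gradient has the closed form $2(1 - 2\\,\\mathrm{lr})^2(\\text{params} - \\text{data})$, which makes the result easy to check.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"lr = 0.1  # inner-loop SGD step size (illustrative value)\n",
"\n",
"def loss_fn(params, data):\n",
"  # Toy quadratic loss, chosen so the meta-gradient has a closed form.\n",
"  return jnp.sum((params - data) ** 2)\n",
"\n",
"def meta_loss_fn(params, data):\n",
"  # Loss after one SGD step; tree_map lets params be any pytree.\n",
"  grads = jax.grad(loss_fn)(params, data)\n",
"  updated = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)\n",
"  return loss_fn(updated, data)\n",
"\n",
"params = jnp.array([1.0, -2.0])\n",
"data = jnp.array([0.5, 0.5])\n",
"\n",
"meta_grads = jax.grad(meta_loss_fn)(params, data)\n",
"# Closed form for this toy loss: 2 * (1 - 2 * lr)**2 * (params - data)\n",
"print(meta_grads)\n",
"```\n",
"\n",
"Compared with the plain gradient $2(\\text{params} - \\text{data})$, the meta-gradient is scaled by $(1 - 2\\,\\mathrm{lr})^2$; that factor is exactly the effect of differentiating through the inner update."
]
},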
{
"cell_type": "markdown",
"metadata": {
"id": "3h9Aj3YyuL6P"
},
"source": [
"## Stopping gradients\n",
"\n",
"Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, we might want some additional control: for instance, we might want to avoid backpropagating gradients through some subset of the computational graph.\n",
"\n",
"Consider, for instance, the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\\theta}(s_{t-1})$ in a state $s_{t-1}$ is parameterised by a linear function."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "fjLqbCb6SiOm"
},
"outputs": [],
"source": [
"# Value function and initial parameters\n",
"value_fn = lambda theta, state: jnp.dot(theta, state)\n",
"theta = jnp.array([0.1, -0.1, 0.])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "85S7HBo1tBzt"
},
"source": [
"Consider a transition from state $s_{t-1}$ to state $s_t$, during which we observed the reward $r_t$:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "T6cRPau6tCSE"
},
"outputs": [],
"source": [
"# An example transition.\n",
"s_tm1 = jnp.array([1., 2., -1.])\n",
"r_t = jnp.array(1.)\n",
"s_t = jnp.array([2., 1., 0.])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QO5CHA9_Sk01"
},
"source": [
"The TD(0) update to the network parameters is:\n",
"\n",
"$$\n",
"\\Delta \\theta = (r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})) \\nabla v_{\\theta}(s_{t-1})\n",
"$$\n",
"\n",
"This update is not the gradient of any loss function.\n",
"\n",
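"However, it can be **written** as a gradient of the pseudo-loss function\n",
"\n",
"$$\n",
"L(\\theta) = [r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})]^2\n",
"$$\n",
"\n",
"if the dependency of the target $r_t + v_{\\theta}(s_t)$ on the parameter $\\theta$ is ignored. Holding the target fixed, $\\nabla_{\\theta} L = -2\\,[r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})]\\,\\nabla v_{\\theta}(s_{t-1}) = -2\\,\\Delta \\theta$, so the TD(0) update is a step along the negative gradient of $L$.\n",
"\n",
"How can we implement this in JAX? If we write the pseudo-loss naively, we get:"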
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "uMcFny2xuOwz",
"outputId": "79c10af9-10b8-4e18-9753-a53918b9d72d"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([ 2.4, -2.4, 2.4], dtype=float32)"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def td_loss(theta, s_tm1, r_t, s_t):\n",
" v_tm1 = value_fn(theta, s_tm1)\n",
" target = r_t + value_fn(theta, s_t)\n",
" return (target - v_tm1) ** 2\n",
"\n",
"td_update = jax.grad(td_loss)\n",
"delta_theta = td_update(theta, s_tm1, r_t, s_t)\n",
"\n",
"delta_theta"
]
},
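{
"cell_type": "markdown",
"metadata": {},
"source": [
"The naive result can be checked by hand. For the linear value function used here, $v_{\\theta}(s) = \\theta^\\top s$, so $\\nabla v_{\\theta}(s) = s$. With the example values, $v_{\\theta}(s_{t-1}) = -0.1$ and $v_{\\theta}(s_t) = 0.1$, giving a TD error of $1 + 0.1 - (-0.1) = 1.2$. Differentiating $L$ through **both** value estimates gives\n",
"\n",
"$$\n",
"\\nabla_{\\theta} L = 2\\,[r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})]\\,(s_t - s_{t-1}) = 2.4 \\cdot (1, -1, 1) = (2.4, -2.4, 2.4),\n",
"$$\n",
"\n",
"which matches the output above. The $s_t$ term is the contribution that flows through the target."
]
},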
{
"cell_type": "markdown",
"metadata": {
"id": "CPnjm59GG4Gq"
},
"source": [
"But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\\theta$.\n",
"\n",
"We can use `jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\\theta$:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "WCeq7trKPS4V",
"outputId": "0f38d754-a871-4c47-8e3a-a961418a24cc"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
]
},
"execution_count": 10,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def td_loss(theta, s_tm1, r_t, s_t):\n",
" v_tm1 = value_fn(theta, s_tm1)\n",
" target = r_t + value_fn(theta, s_t)\n",
" return (jax.lax.stop_gradient(target) - v_tm1) ** 2\n",
"\n",
"td_update = jax.grad(td_loss)\n",
"delta_theta = td_update(theta, s_tm1, r_t, s_t)\n",
"\n",
"delta_theta"
]
},
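{
"cell_type": "markdown",
"metadata": {},
"source": [
"To connect this output to the TD(0) formula, the sketch below recomputes $\\Delta\\theta$ directly from its definition and compares it with the pseudo-loss gradient. For a linear value function, $\\nabla v_{\\theta}(s_{t-1}) = s_{t-1}$. The block re-creates the example values from above so that it runs on its own.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"value_fn = lambda theta, state: jnp.dot(theta, state)\n",
"theta = jnp.array([0.1, -0.1, 0.])\n",
"s_tm1 = jnp.array([1., 2., -1.])\n",
"r_t = jnp.array(1.)\n",
"s_t = jnp.array([2., 1., 0.])\n",
"\n",
"def td_loss(theta, s_tm1, r_t, s_t):\n",
"  v_tm1 = value_fn(theta, s_tm1)\n",
"  target = r_t + value_fn(theta, s_t)\n",
"  return (jax.lax.stop_gradient(target) - v_tm1) ** 2\n",
"\n",
"# TD(0) update from the formula: TD error times grad v(s_tm1),\n",
"# which for a linear value function is just s_tm1.\n",
"td_error = r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)\n",
"delta_theta = td_error * s_tm1\n",
"\n",
"pseudo_loss_grad = jax.grad(td_loss)(theta, s_tm1, r_t, s_t)\n",
"print(delta_theta)              # the TD(0) update\n",
"print(-0.5 * pseudo_loss_grad)  # the same values, from the pseudo-loss gradient\n",
"```\n",
"\n",
"With the target frozen, the gradient equals $-2\\,\\Delta\\theta$ (here $\\Delta\\theta = 1.2\\, s_{t-1}$), so a gradient-descent step on the pseudo-loss moves $\\theta$ in the TD(0) direction."
]
},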
{
"cell_type": "markdown",
"metadata": {
"id": "TNF0CkwOTKpD"
},
"source": [
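"Wrapping `target` in `jax.lax.stop_gradient` makes JAX treat it as if it did **not** depend on the parameters $\\theta$, so the resulting gradient is $-2\\,\\Delta\\theta$ and a gradient-descent step on the pseudo-loss moves $\\theta$ in the direction of the TD(0) update.\n",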
"\n",
"`jax.lax.stop_gradient` is also useful in other settings, for instance when you want the gradient from some loss to affect only a subset of the parameters of a neural network (because, say, the other parameters are trained using a different loss).\n",
"\n",
"## Straight-through estimator using `stop_gradient`\n",
"\n",
"The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \\mathbb{R}^n \\to \\mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "hdORJENmVHvX",
"outputId": "f0839541-46a4-45a9-fce7-ead08f20046b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f(x): 3.0\n",
"straight_through_f(x): 3.0\n",
"grad(f)(x): 0.0\n",
"grad(straight_through_f)(x): 1.0\n"
]
}
],
"source": [
"def f(x):\n",
" return jnp.round(x) # non-differentiable\n",
"\n",
"def straight_through_f(x):\n",
" return x + jax.lax.stop_gradient(f(x) - x)\n",
"\n",
"print(\"f(x): \", f(3.2))\n",
"print(\"straight_through_f(x):\", straight_through_f(3.2))\n",
"\n",
"print(\"grad(f)(x):\", jax.grad(f)(3.2))\n",
"print(\"grad(straight_through_f)(x):\", jax.grad(straight_through_f)(3.2))"
]
},
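{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same trick applies when the non-differentiable function acts on a vector and feeds into a larger computation. In the sketch below, `straight_through_round` and the sum-of-squares `objective` are illustrative names: the forward pass sees the rounded values, while the backward pass treats rounding as the identity, so the gradient comes out as `2 * round(x)` rather than zero.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def straight_through_round(x):\n",
"  # Forward: round(x). Backward: identity, because the stop_gradient\n",
"  # term contributes nothing to the derivative.\n",
"  return x + jax.lax.stop_gradient(jnp.round(x) - x)\n",
"\n",
"def objective(x):\n",
"  # A larger function that consumes the rounded values.\n",
"  return jnp.sum(straight_through_round(x) ** 2)\n",
"\n",
"x = jnp.array([0.4, 1.6, -2.7])\n",
"print(objective(x))            # 0**2 + 2**2 + (-3)**2 = 13, up to float rounding\n",
"print(jax.grad(objective)(x))  # 2 * round(x)\n",
"```"
]
},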
{
"cell_type": "markdown",
"metadata": {
"id": "Wx3RNE0Sw5mn"
},
"source": [
"## Per-example gradients\n",
"\n",
"While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.\n",
"\n",
"For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.\n",
"\n",
"In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients, are typically very inefficient.\n",
"\n",
"In JAX, we can define the code to compute per-sample gradients in a way that is both easy and efficient.\n",
"\n",
"Just combine the `jit`, `vmap` and `grad` transformations together:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "tFLyd9ifw4GG",
"outputId": "bf3ad4a3-102d-47a6-ece0-f4a8c9e5d434"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 12,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))\n",
"\n",
"# Test it:\n",
"batched_s_tm1 = jnp.stack([s_tm1, s_tm1])\n",
"batched_r_t = jnp.stack([r_t, r_t])\n",
"batched_s_t = jnp.stack([s_t, s_t])\n",
"\n",
"perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VxvYVEYQYiS_"
},
"source": [
"Let's walk through this one transformation at a time.\n",
"\n",
"First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "rPO67QQrY5Bk",
"outputId": "fbb45b98-2dbf-4865-e6e5-87dc3eef5560"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
]
},
"execution_count": 13,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"dtdloss_dtheta = jax.grad(td_loss)\n",
"\n",
"dtdloss_dtheta(theta, s_tm1, r_t, s_t)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cU36nVAlcnJ0"
},
"source": [
"This function computes one row of the array above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c6DQF0b3ZA5u"
},
"source": [
"Then, we vectorise this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs -- each output in the batch corresponds to the gradient for the corresponding member of the input batch."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "5agbNKavaNDM",
"outputId": "ab081012-88ab-4904-a367-68e9f81445f0"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 14,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"almost_perex_grads = jax.vmap(dtdloss_dtheta)\n",
"\n",
"batched_theta = jnp.stack([theta, theta])\n",
"almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
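{
"cell_type": "markdown",
"metadata": {},
"source": [
"`jax.vmap` produces the same result as looping over the batch in Python and stacking the per-example outputs, but as a single vectorised computation. The sketch below checks this on a hypothetical linear-regression loss (not the TD loss above), batching `theta` explicitly to mirror the cell above.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def loss(theta, x, y):\n",
"  # Hypothetical per-example squared-error loss for a linear model.\n",
"  return (jnp.dot(theta, x) - y) ** 2\n",
"\n",
"grad_fn = jax.grad(loss)\n",
"\n",
"theta = jnp.array([0.5, -1.0])\n",
"xs = jnp.array([[1., 2.], [3., 4.], [5., 6.]])\n",
"ys = jnp.array([1., 0., -1.])\n",
"\n",
"# Batch theta explicitly: every argument is mapped over axis 0.\n",
"thetas = jnp.stack([theta] * 3)\n",
"batched = jax.vmap(grad_fn)(thetas, xs, ys)\n",
"\n",
"# Reference: an explicit Python loop over the examples.\n",
"looped = jnp.stack([grad_fn(thetas[i], xs[i], ys[i]) for i in range(3)])\n",
"\n",
"print(batched.shape)  # (3, 2): one gradient per example\n",
"```"
]
},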
{
"cell_type": "markdown",
"metadata": {
"id": "K-v34yLuan7k"
},
"source": [
"This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the `jax.vmap`, specifying `theta` as `None` and the other arguments as `0`. This makes the resulting function map over an extra leading axis only for the other arguments, leaving `theta` unbatched, as we want:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "S6kd5MujbGrr",
"outputId": "d3d731ef-3f7d-4a0a-ce91-7df57627ddbd"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 15,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))\n",
"\n",
"inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O0hbsm70be5T"
},
"source": [
"Almost there! This does what we want, but is slower than it has to be. Now, we wrap the whole thing in a `jax.jit` to get the compiled, efficient version of the same function:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "Fvr709FcbrSW",
"outputId": "627db899-5620-4bed-8d34-cd1364d3d187"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 16,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"perex_grads = jax.jit(inefficient_perex_grads)\n",
"\n",
"perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
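{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, `perex_grads` has already been called once before it is timed, so its compilation cost is not part of the measurement. Outside IPython, a minimal sketch of the same practice is to call the jitted function once to trigger compilation, then time later calls with `block_until_ready()`, since JAX dispatches work asynchronously. The loss, array shapes, and iteration count below are arbitrary assumptions.\n",
"\n",
"```python\n",
"import time\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def loss(theta, x, y):\n",
"  # Hypothetical linear-regression loss, used only for timing.\n",
"  return (jnp.dot(theta, x) - y) ** 2\n",
"\n",
"perex = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))\n",
"\n",
"theta = jnp.zeros(16)\n",
"xs = jnp.ones((128, 16))\n",
"ys = jnp.ones(128)\n",
"\n",
"# Warm-up call: traces and compiles, so keep it out of the timing.\n",
"perex(theta, xs, ys).block_until_ready()\n",
"\n",
"n_iters = 100\n",
"start = time.perf_counter()\n",
"for _ in range(n_iters):\n",
"  # Wait for the asynchronously dispatched computation to finish.\n",
"  out = perex(theta, xs, ys).block_until_ready()\n",
"seconds_per_call = (time.perf_counter() - start) / n_iters\n",
"\n",
"print(out.shape, seconds_per_call)\n",
"```"
]
},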
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "FH42yzbHcNs2",
"outputId": "c8e52f93-615a-4ce7-d8ab-fb6215995a39"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100 loops, best of 5: 7.74 ms per loop\n",
"10000 loops, best of 5: 86.2 µs per loop\n"
]
}
],
"source": [
"%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()\n",
"%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()"
]
}
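,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The introduction to this section mentions clipping on a sample-by-sample basis. As a minimal sketch, assuming a hypothetical linear-model loss and a threshold `max_norm` (neither comes from the TD example above), per-example gradients can be rescaled to a maximum L2 norm before they are averaged, which a batch-averaged gradient alone cannot express.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def loss(theta, x, y):\n",
"  # Hypothetical per-example squared-error loss for a linear model.\n",
"  return (jnp.dot(theta, x) - y) ** 2\n",
"\n",
"# Per-example gradients: theta is shared, the data is batched along axis 0.\n",
"perex_grad_fn = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))\n",
"\n",
"def clip_per_example(grads, max_norm):\n",
"  # Rescale each row whose L2 norm exceeds max_norm; leave the others as-is.\n",
"  norms = jnp.linalg.norm(grads, axis=1, keepdims=True)\n",
"  scale = jnp.minimum(1.0, max_norm / jnp.maximum(norms, 1e-12))\n",
"  return grads * scale\n",
"\n",
"theta = jnp.array([0.5, -1.0])\n",
"xs = jnp.array([[1., 2.], [3., 4.], [0.1, 0.1]])\n",
"ys = jnp.array([1., 0., 0.])\n",
"\n",
"clipped = clip_per_example(perex_grad_fn(theta, xs, ys), max_norm=1.0)\n",
"update = jnp.mean(clipped, axis=0)  # aggregate only after clipping\n",
"\n",
"print(jnp.linalg.norm(clipped, axis=1))  # each norm is at most 1.0, up to rounding\n",
"```\n",
"\n",
"Here the first two examples have large gradients and are scaled down to unit norm, while the third, with a gradient of norm about 0.014, passes through unchanged."
]
}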
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Advanced Grads",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}