rocm_jax/docs/notebooks/autodiff_remat.ipynb
2023-02-23 15:06:55 -08:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "29WqUVkCXjDD"
},
"source": [
"## Control autodiff's saved values with `jax.checkpoint` (aka `jax.remat`)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qaIsQSh1XoKF"
},
"source": [
"### TL;DR\n",
"\n",
"Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n",
"\n",
"**Don't miss the [practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**\n",
"\n",
"Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values _residuals_:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[5,4] from the argument 'W1'\n",
"f32[6,5] from the argument 'W2'\n",
"f32[7,6] from the argument 'W3'\n",
"f32[4] from the argument 'x'\n",
"f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)\n",
"f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)\n",
"f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)\n",
"f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)\n",
"f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)\n"
]
}
],
"source": [
"def g(W, x):\n",
" y = jnp.dot(W, x)\n",
" return jnp.sin(y)\n",
"\n",
"def f(W1, W2, W3, x):\n",
" x = g(W1, x)\n",
" x = g(W2, x)\n",
" x = g(W3, x)\n",
" return x\n",
"\n",
"W1 = jnp.ones((5, 4))\n",
"W2 = jnp.ones((6, 5))\n",
"W3 = jnp.ones((7, 6))\n",
"x = jnp.ones(4)\n",
"\n",
"# Inspect the 'residual' values to be saved on the forward pass\n",
"# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`\n",
"from jax.ad_checkpoint import print_saved_residuals\n",
"jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "97vvWfI-fSSF"
},
"source": [
"By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[5,4] from the argument 'W1'\n",
"f32[6,5] from the argument 'W2'\n",
"f32[7,6] from the argument 'W3'\n",
"f32[4] from the argument 'x'\n",
"f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)\n",
"f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)\n"
]
}
],
"source": [
"def f2(W1, W2, W3, x):\n",
" x = jax.checkpoint(g)(W1, x)\n",
" x = jax.checkpoint(g)(W2, x)\n",
" x = jax.checkpoint(g)(W3, x)\n",
" return x\n",
"\n",
"jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)"
]
},
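{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, here is a minimal sketch that compares gradients with and without `jax.checkpoint`. It re-declares `g`, `f`, and `f2` from above so that it runs on its own. The helper `make_loss` and the other new names are only for illustration. The gradients should agree, because `jax.checkpoint` only changes which values are saved and which are recomputed:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def g(W, x):\n",
"  return jnp.sin(jnp.dot(W, x))\n",
"\n",
"def f(W1, W2, W3, x):\n",
"  return g(W3, g(W2, g(W1, x)))\n",
"\n",
"def f2(W1, W2, W3, x):\n",
"  # Same computation, with each stage wrapped in jax.checkpoint\n",
"  x = jax.checkpoint(g)(W1, x)\n",
"  x = jax.checkpoint(g)(W2, x)\n",
"  return jax.checkpoint(g)(W3, x)\n",
"\n",
"def make_loss(fn):\n",
"  # Reduce the output to a scalar so that jax.grad applies\n",
"  return lambda *a: jnp.sum(fn(*a))\n",
"\n",
"args = (jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6)), jnp.ones(4))\n",
"grads = jax.grad(make_loss(f), argnums=(0, 1, 2, 3))(*args)\n",
"grads2 = jax.grad(make_loss(f2), argnums=(0, 1, 2, 3))(*args)\n",
"\n",
"same = all(bool(jnp.allclose(a, b)) for a, b in zip(grads, grads2))\n",
"print(same)\n",
"```"
]
},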
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here the values of two `sin` applications are saved because they are arguments\n",
"in subsequent applications of the `jax.checkpoint`-decorated `g` function, and\n",
"inputs to a `jax.checkpoint`-decorated function may be saved. But no values of\n",
"`cos` applications are saved."
]
},
{
"cell_type": "markdown",
"id": "807b4764",
"metadata": {
"id": "CyRR3mTpjRtl"
},
"source": [
"To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_. Here is an example that saves only the results of `dot` operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[5,4] from the argument 'W1'\n",
"f32[6,5] from the argument 'W2'\n",
"f32[7,6] from the argument 'W3'\n",
"f32[4] from the argument 'x'\n",
"f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)\n",
"f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)\n",
"f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)\n"
]
}
],
"source": [
"f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
"jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)"
]
},
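{
"cell_type": "markdown",
"metadata": {},
"source": [
"A policy only decides which intermediates are saveable, so swapping policies should leave the gradients unchanged. The following minimal sketch uses three built-in policies from `jax.checkpoint_policies` (`nothing_saveable`, `dots_with_no_batch_dims_saveable`, and `everything_saveable`). For each policy it prints the saved residuals and computes the gradients. The names `policies`, `make_loss`, and `grads` are only for illustration:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.ad_checkpoint import print_saved_residuals\n",
"\n",
"def g(W, x):\n",
"  return jnp.sin(jnp.dot(W, x))\n",
"\n",
"def f(W1, W2, W3, x):\n",
"  return g(W3, g(W2, g(W1, x)))\n",
"\n",
"def make_loss(fn):\n",
"  return lambda *a: jnp.sum(fn(*a))\n",
"\n",
"args = (jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6)), jnp.ones(4))\n",
"\n",
"# From saving nothing (maximum recomputation) to saving everything\n",
"policies = {\n",
"  'nothing_saveable': jax.checkpoint_policies.nothing_saveable,\n",
"  'dots_with_no_batch_dims_saveable': jax.checkpoint_policies.dots_with_no_batch_dims_saveable,\n",
"  'everything_saveable': jax.checkpoint_policies.everything_saveable,\n",
"}\n",
"\n",
"grads = {}\n",
"for name, policy in policies.items():\n",
"  f_pol = jax.checkpoint(f, policy=policy)\n",
"  print(f'--- {name}')\n",
"  print_saved_residuals(f_pol, *args)  # which residuals this policy keeps\n",
"  grads[name] = jax.grad(make_loss(f_pol), argnums=(0, 1, 2, 3))(*args)\n",
"\n",
"# The policy trades memory for FLOPs; the gradient values should not change\n",
"reference = grads['everything_saveable']\n",
"agree = all(\n",
"  bool(jnp.allclose(a, b))\n",
"  for gr in grads.values()\n",
"  for a, b in zip(gr, reference)\n",
")\n",
"print(agree)\n",
"```"
]
},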
{
"cell_type": "markdown",
"metadata": {
"id": "9fe6W0YxlfKa"
},
"source": [
"You can also use policies to refer to intermediate values you name using `jax.ad_checkpoint.checkpoint_name`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[5,4] from the argument 'W1'\n",
"f32[6,5] from the argument 'W2'\n",
"f32[7,6] from the argument 'W3'\n",
"f32[4] from the argument 'x'\n",
"f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)\n"
]
}
],
"source": [
"from jax.ad_checkpoint import checkpoint_name\n",
"\n",
"def f4(W1, W2, W3, x):\n",
" x = checkpoint_name(g(W1, x), name='a')\n",
" x = checkpoint_name(g(W2, x), name='b')\n",
" x = checkpoint_name(g(W3, x), name='c')\n",
" return x\n",
"\n",
"f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))\n",
"jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)"
]
},
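{
"cell_type": "markdown",
"metadata": {},
"source": [
"`save_only_these_names` accepts more than one name. The following sketch reuses the toy `f4` from above and keeps both `'a'` and `'b'`. This particular pair of names is an arbitrary choice for illustration. Naming values with `checkpoint_name` should not change what the function returns:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.ad_checkpoint import checkpoint_name, print_saved_residuals\n",
"\n",
"def g(W, x):\n",
"  return jnp.sin(jnp.dot(W, x))\n",
"\n",
"def f4(W1, W2, W3, x):\n",
"  x = checkpoint_name(g(W1, x), name='a')\n",
"  x = checkpoint_name(g(W2, x), name='b')\n",
"  x = checkpoint_name(g(W3, x), name='c')\n",
"  return x\n",
"\n",
"args = (jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6)), jnp.ones(4))\n",
"\n",
"# Save the intermediates named 'a' and 'b'; anything else needed is recomputed\n",
"f4_ab = jax.checkpoint(\n",
"    f4, policy=jax.checkpoint_policies.save_only_these_names('a', 'b'))\n",
"print_saved_residuals(f4_ab, *args)\n",
"\n",
"out_plain = f4(*args)\n",
"out_ckpt = f4_ab(*args)\n",
"print(jnp.allclose(out_plain, out_ckpt))\n",
"```"
]
},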
{
"cell_type": "markdown",
"metadata": {
"id": "40oy-FbmVkDc"
},
"source": [
"When playing around with these toy examples, we can get a closer look at what's going on using the `print_fwd_bwd` utility defined in this notebook:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from jax.tree_util import tree_flatten, tree_unflatten\n",
"\n",
"from rich.console import Console\n",
"from rich.table import Table\n",
"import rich.text\n",
"\n",
"def print_fwd_bwd(f, *args, **kwargs) -> None:\n",
" args, in_tree = tree_flatten((args, kwargs))\n",
"\n",
" def f_(*args):\n",
" args, kwargs = tree_unflatten(in_tree, args)\n",
" return f(*args, **kwargs)\n",
"\n",
" fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr\n",
"\n",
" y, f_vjp = jax.vjp(f_, *args)\n",
" res, in_tree = tree_flatten(f_vjp)\n",
"\n",
" def g_(*args):\n",
" *res, y = args\n",
" f_vjp = tree_unflatten(in_tree, res)\n",
" return f_vjp(y)\n",
"\n",
" bwd = jax.make_jaxpr(g_)(*res, y).jaxpr\n",
"\n",
" table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)\n",
" table.add_row(\"[bold green]forward computation:\",\n",
" \"[bold green]backward computation:\")\n",
" table.add_row(rich.text.Text.from_ansi(str(fwd)),\n",
" rich.text.Text.from_ansi(str(bwd)))\n",
" console = Console(width=240, force_jupyter=True)\n",
" console.print(table)\n",
"\n",
"def _renderable_repr(self):\n",
" return self.html\n",
"rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"> \n",
" <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">forward computation:</span> <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">backward computation:</span> \n",
" \n",
" { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[4]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: 
#800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[4]</span> \n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d</span> <span style=\"color: #000000; text-decoration-color: #000000\"> i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin e</span> <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">k</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul j a</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos e</span> <span style=\"color: #000000; text-decoration-color: #000000\"> l</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f</span> <span style=\"color: #000000; text-decoration-color: #000000\"> m</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([], []), ([], []))] k b</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin h</span> <span style=\"color: #000000; text-decoration-color: #000000\"> n</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul l d</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos h</span> <span style=\"color: #000000; text-decoration-color: #000000\"> o</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> k</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i</span> <span style=\"color: #000000; text-decoration-color: #000000\"> p</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([], []), ([], []))] n e</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> l</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin k</span> <span style=\"color: #000000; text-decoration-color: #000000\"> q</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul o g</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> m</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos k</span> <span style=\"color: #000000; text-decoration-color: #000000\"> r</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(l, m, i, c, j, f, b, g, d, a) }</span> <span style=\"color: #000000; text-decoration-color: #000000\"> s</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([], []), ([], []))] q h</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(s, p, m, r) }</span> \n",
"</pre>\n"
],
"text/plain": [
"<rich.jupyter.JupyterRenderable at 0x7f5e4659a820>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# no use of jax.checkpoint:\n",
"print_fwd_bwd(f, W1, W2, W3, x)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"> \n",
" <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">forward computation:</span> <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">backward computation:</span> \n",
" \n",
" { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[4]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: 
#800080\">:f32[4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d</span> <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> k</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> l</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = remat2[</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin e</span> <span style=\"color: #000000; text-decoration-color: #000000\"> differentiated=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f</span> <span style=\"color: #000000; text-decoration-color: #000000\"> jaxpr={ </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; m</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> n</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> o</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> p</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> q</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> r</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin g</span> <span style=\"color: #000000; text-decoration-color: #000000\"> s</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> t</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h</span> <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">u</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin m</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin i</span> <span style=\"color: #000000; text-decoration-color: #000000\"> v</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos m</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(j, e, g, i, a, b, c, d) }</span> <span style=\"color: #000000; text-decoration-color: #000000\"> w</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin n</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> x</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos n</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> y</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos o</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> z</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul t y</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> ba</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bb</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul ba x</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bc</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bd</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul bc v</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> be</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bf</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[5,4]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([], []), ([], []))] bd s</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bg</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[6,5]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([], []), ([], []))] bb u</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bh</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[7,6]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = dot_general[dimension_numbers=(([], []), ([], []))] z w</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(bf, bg, bh, be) }</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> policy=&lt;function dot_with_no_batch_dims at 0x7f5e469b1700&gt;</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> prevent_cse=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> ] a b c d e f g h</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(i, j, k, l) }</span> \n",
"</pre>\n"
],
"text/plain": [
"<rich.jupyter.JupyterRenderable at 0x7f5e42c8f6d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:\n",
"print_fwd_bwd(f3, W1, W2, W3, x)"
]
},
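{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `rich` isn't available, you can get a rough plain-text view by printing the jaxprs directly with `jax.make_jaxpr`. The following sketch mirrors the first step of `print_fwd_bwd`. It prints the forward pass of `jax.vjp`, whose outputs include the residuals, and then the whole gradient computation as a single jaxpr. The names `fwd` and `full` are only for illustration:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def g(W, x):\n",
"  return jnp.sin(jnp.dot(W, x))\n",
"\n",
"def f(W1, W2, W3, x):\n",
"  return g(W3, g(W2, g(W1, x)))\n",
"\n",
"f3 = jax.checkpoint(\n",
"    f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
"args = (jnp.ones((5, 4)), jnp.ones((6, 5)), jnp.ones((7, 6)), jnp.ones(4))\n",
"\n",
"# Forward pass: the outputs of jax.vjp include the saved residuals\n",
"fwd = jax.make_jaxpr(lambda *a: jax.vjp(f3, *a))(*args)\n",
"print(fwd)\n",
"\n",
"# Forward and backward passes together, as one jaxpr for the gradient\n",
"loss = lambda *a: jnp.sum(f3(*a))\n",
"full = jax.make_jaxpr(jax.grad(loss, argnums=(0, 1, 2, 3)))(*args)\n",
"print(full)\n",
"```"
]
},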
{
"cell_type": "markdown",
"metadata": {
"id": "UsvnQJYomcub"
},
"source": [
"### Let's think step by step\n",
"\n",
"You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VMfwm_yinvoZ"
},
"source": [
"#### Fundamentals of `jax.checkpoint`\n",
"\n",
"In both `jax.linearize` and `jax.vjp` there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.\n",
"\n",
"One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of `sin_vjp`:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def sin_vjp(x):\n",
" y = jnp.sin(x)\n",
" cos_x = jnp.cos(x)\n",
" return y, lambda y_bar: cos_x * y_bar"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7swp_LJcorL6"
},
"source": [
"Another valid implementation would compute the value of `jnp.cos(x)` on the backward pass rather than on the forward pass:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def sin_vjp2(x):\n",
" y = jnp.sin(x)\n",
" return y, lambda y_bar: jnp.cos(x) * y_bar"
]
},
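{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following small check uses a 1-D float input chosen only for illustration. It confirms that both hand-written rules agree with the VJP that JAX derives for `jnp.sin`. The two rules differ only in when `jnp.cos(x)` is evaluated:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def sin_vjp(x):\n",
"  y = jnp.sin(x)\n",
"  cos_x = jnp.cos(x)  # computed on the forward pass and closed over\n",
"  return y, lambda y_bar: cos_x * y_bar\n",
"\n",
"def sin_vjp2(x):\n",
"  y = jnp.sin(x)\n",
"  # cos(x) is computed only when the backward function is called\n",
"  return y, lambda y_bar: jnp.cos(x) * y_bar\n",
"\n",
"x = jnp.linspace(-1.0, 1.0, 5)\n",
"y_bar = jnp.ones_like(x)\n",
"\n",
"y_ref, vjp_ref = jax.vjp(jnp.sin, x)\n",
"(x_bar_ref,) = vjp_ref(y_bar)\n",
"\n",
"y1, bwd1 = sin_vjp(x)\n",
"y2, bwd2 = sin_vjp2(x)\n",
"print(jnp.allclose(bwd1(y_bar), x_bar_ref), jnp.allclose(bwd2(y_bar), x_bar_ref))\n",
"```"
]
},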
{
"cell_type": "markdown",
"metadata": {
"id": "uDaHIXzHo18i"
},
"source": [
"For this particular function, the amount of memory used by the two versions is the same, though we've reduced the FLOPs for the primal computation (i.e. the forward pass) and increased the FLOPs for the cotangent computation (i.e. the backward pass).\n",
"\n",
"There's another choice when it comes to function composition. Recall our VJP rule for a composition of two functions:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def f(x):\n",
" y = g(x)\n",
" z = h(y)\n",
" return z\n",
"\n",
"def f_vjp(x):\n",
" y, g_vjp = jax.vjp(g, x)\n",
" z, h_vjp = jax.vjp(h, y)\n",
" def f_bwd(z_bar):\n",
" y_bar, = h_vjp(z_bar)\n",
" x_bar, = g_vjp(y_bar)\n",
" return x_bar\n",
" return z, f_bwd"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6pC6Ng-6pigH"
},
"source": [
"An alternative is:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def f_vjp_checkpoint(x):\n",
" y = g(x)\n",
" z, h_vjp = jax.vjp(h, y)\n",
" def f_bwd2(z_bar):\n",
" y_bar, = h_vjp(z_bar)\n",
" _, g_vjp = jax.vjp(g, x)\n",
" x_bar, = g_vjp(y_bar)\n",
" return x_bar\n",
" return z, f_bwd2"
]
},
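{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the comparison concrete, here is a sketch with hypothetical stages `g` and `h`. Any differentiable functions would do; these are chosen only so that the code runs, with `h` returning a scalar. Both hand-written VJPs should match `jax.grad(f)`. They differ only in when `g`'s residuals are computed:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical stages, for illustration only\n",
"def g(x):\n",
"  return jnp.sin(x) * x\n",
"\n",
"def h(y):\n",
"  return jnp.sum(jnp.cos(y))\n",
"\n",
"def f(x):\n",
"  return h(g(x))\n",
"\n",
"def f_vjp(x):\n",
"  y, g_vjp = jax.vjp(g, x)  # g's residuals are kept alive in g_vjp\n",
"  z, h_vjp = jax.vjp(h, y)\n",
"  def f_bwd(z_bar):\n",
"    y_bar, = h_vjp(z_bar)\n",
"    x_bar, = g_vjp(y_bar)\n",
"    return x_bar\n",
"  return z, f_bwd\n",
"\n",
"def f_vjp_checkpoint(x):\n",
"  y = g(x)  # g's residuals are discarded\n",
"  z, h_vjp = jax.vjp(h, y)\n",
"  def f_bwd2(z_bar):\n",
"    y_bar, = h_vjp(z_bar)\n",
"    _, g_vjp = jax.vjp(g, x)  # re-run g to rebuild its residuals\n",
"    x_bar, = g_vjp(y_bar)\n",
"    return x_bar\n",
"  return z, f_bwd2\n",
"\n",
"x = jnp.arange(3.0)\n",
"z1, bwd1 = f_vjp(x)\n",
"z2, bwd2 = f_vjp_checkpoint(x)\n",
"z_bar = jnp.ones_like(z1)\n",
"expected = jax.grad(f)(x)\n",
"print(jnp.allclose(bwd1(z_bar), expected), jnp.allclose(bwd2(z_bar), expected))\n",
"```"
]
},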
{
"cell_type": "markdown",
"metadata": {
"id": "JYMw6oxtp6SH"
},
"source": [
"In words, this alternative implementation doesn't compute `g_vjp`, or the residual values in its closure, on the forward pass. Instead, it computes them only in the backward pass `f_bwd2`. That means `f_vjp_checkpoint` requires less memory: if `g` and `h` each required similar amounts of memory for their residuals, each much larger than `x`, then the function produced by `f_vjp_checkpoint(x)` requires about half the memory of the one produced by `f_vjp(x)`!\n",
"\n",
"The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LqTrjPoGqrK7"
},
"source": [
"We can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using `jax.checkpoint` in an alternative definition of the original function `f`:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def f_checkpoint(x):\n",
" y = jax.checkpoint(g)(x)\n",
" z = h(y)\n",
" return z"
]
},
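{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following sketch uses hypothetical `g` and `h` that each produce a few intermediates, so that `print_saved_residuals` has something to show. Once `jax.checkpoint` is applied to `g`, `g`'s internal intermediates should no longer appear among the saved residuals, while the gradient should stay the same:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.ad_checkpoint import print_saved_residuals\n",
"\n",
"# Hypothetical stages, for illustration only\n",
"def g(x):\n",
"  return jnp.sin(jnp.sin(x))\n",
"\n",
"def h(y):\n",
"  return jnp.sum(jnp.cos(jnp.cos(y)))\n",
"\n",
"def f(x):\n",
"  return h(g(x))\n",
"\n",
"def f_checkpoint(x):\n",
"  y = jax.checkpoint(g)(x)\n",
"  return h(y)\n",
"\n",
"x = jnp.arange(4.0)\n",
"print_saved_residuals(f, x)             # residuals from both g and h\n",
"print_saved_residuals(f_checkpoint, x)  # g's intermediates should be gone\n",
"\n",
"grad_plain = jax.grad(f)(x)\n",
"grad_ckpt = jax.grad(f_checkpoint)(x)\n",
"print(jnp.allclose(grad_plain, grad_ckpt))\n",
"```"
]
},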
{
"cell_type": "markdown",
"metadata": {
"id": "bjNcEAnLrUwy"
},
"source": [
"In other words, we apply `jax.checkpoint` to `g`, the first stage of `f`, rather than to `f` itself. This way, when we evaluate `jax.grad(f_checkpoint)(x)`, we'd get a computation like:\n",
"1. run the forward pass of `g`, discarding residual values;\n",
"2. run the forward pass of `h`, saving residuals;\n",
"3. run the backward pass of `h`, consuming residuals from step 2;\n",
"4. re-run the forward pass of `g`, saving residuals;\n",
"5. run the backward pass of `g`, consuming residuals from step 4.\n",
"\n",
"That is, by evaluating `jax.grad(f_checkpoint)(x)` we'd get the same computation as:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def f_checkpoint_grad(x):\n",
"  y = g(x)                  # step 1\n",
"  _, h_vjp = jax.vjp(h, y)  # step 2\n",
"  y_bar, = h_vjp(1.0)       # step 3\n",
"  _, g_vjp = jax.vjp(g, x)  # step 4\n",
"  x_bar, = g_vjp(y_bar)     # step 5\n",
"  return x_bar"
]
},
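{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch to spot-check that equivalence numerically, again assuming the hypothetical scalar stand-ins `g_demo` and `h_demo`. It compares `jax.grad` of the checkpointed composition with the hand-written five-step version and with plain `jax.grad`. It then prints the saved residuals of the plain and checkpointed versions. We'd expect the checkpointed one to keep the argument in place of `g_demo`'s internal values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.ad_checkpoint import print_saved_residuals\n",
"\n",
"# Hypothetical scalar stand-ins for the stages g and h.\n",
"def g_demo(x):\n",
"  return jnp.sin(jnp.sin(x))\n",
"\n",
"def h_demo(y):\n",
"  return jnp.cos(jnp.cos(y))\n",
"\n",
"def f_plain_demo(x):\n",
"  return h_demo(g_demo(x))\n",
"\n",
"def f_checkpoint_demo(x):\n",
"  y = jax.checkpoint(g_demo)(x)  # g_demo's residuals are recomputed on the backward pass\n",
"  return h_demo(y)\n",
"\n",
"def f_checkpoint_grad_demo(x):\n",
"  y = g_demo(x)                     # step 1\n",
"  _, h_vjp = jax.vjp(h_demo, y)     # step 2\n",
"  y_bar, = h_vjp(jnp.ones_like(y))  # step 3 (h_demo is elementwise, so its output has y's shape)\n",
"  _, g_vjp = jax.vjp(g_demo, x)     # step 4\n",
"  x_bar, = g_vjp(y_bar)             # step 5\n",
"  return x_bar\n",
"\n",
"x_demo = jnp.float32(0.5)\n",
"g1 = jax.grad(f_plain_demo)(x_demo)\n",
"g2 = jax.grad(f_checkpoint_demo)(x_demo)\n",
"g3 = f_checkpoint_grad_demo(x_demo)\n",
"print(jnp.allclose(g1, g2), jnp.allclose(g2, g3))\n",
"\n",
"print_saved_residuals(f_plain_demo, x_demo)\n",
"print('---')\n",
"print_saved_residuals(f_checkpoint_demo, x_demo)"
]
},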
{
"cell_type": "markdown",
"metadata": {
"id": "tfssPyb2sQgw"
},
"source": [
"In general, `jax.checkpoint(foo)` is a new function with the same input-output behavior as `foo`, but it behaves differently under autodiff: in particular under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`), though not under `jax.jvp`. When differentiated, only the inputs to a `jax.checkpoint`-decorated function are stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and the Jacobian coefficient values needed for the backward pass) are recomputed."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HVvS_S5zsVZ-"
},
"source": [
"Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:\n",
"1. run the forward pass, discarding all residuals;\n",
"2. immediately re-run the forward pass, saving residuals;\n",
"3. run the backward pass, consuming residuals from step 2.\n",
"\n",
"That is, in code we'd have something like:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def f_grad_bad(x):\n",
" _ = f(x) # step 1\n",
" _, f_vjp = jax.vjp(f, x) # step 2\n",
" x_bar, = f_vjp(1.0) # step 3\n",
" return x_bar"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yayWMylctuhM"
},
"source": [
"We also wouldn't get any memory savings by applying `jax.checkpoint` to `h`, the second stage of `f`. That's because evaluating `jax.grad(lambda x: jax.checkpoint(h)(g(x)))` would lead to a computation like:\n",
"1. run the forward pass of `g`, saving residuals;\n",
"2. run the forward pass of `h`, discarding residuals;\n",
"3. immediately re-run the forward pass of `h`, saving residuals;\n",
"4. run the backward pass of `h`, consuming residuals from step 3;\n",
"5. run the backward pass of `g`, consuming residuals from step 1.\n",
"\n",
"That is, in code we'd have something like:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def f_grad_bad2(x):\n",
"  y, g_vjp = jax.vjp(g, x)  # step 1\n",
"  z = h(y)                  # step 2\n",
"  _, h_vjp = jax.vjp(h, y)  # step 3\n",
"  y_bar, = h_vjp(1.0)       # step 4\n",
"  x_bar, = g_vjp(y_bar)     # step 5\n",
"  return x_bar"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "suLiGdSFxUOc"
},
"source": [
"Slightly more generally, if we had a chain composition of functions, like `f = lambda x: f3(f2(f1(x)))`, and we were interested in evaluating `jax.grad(f)`, we could say that:\n",
"* we shouldn't apply `jax.checkpoint` to the whole function `f`, since that wouldn't save any memory (and would perform wasteful recomputation);\n",
"* we shouldn't apply `jax.checkpoint` to the last sub-function `f3`, since that wouldn't save any memory (and would perform wasteful recomputation);\n",
"* we could apply `jax.checkpoint` to `f1`, `f2`, or their composition `lambda x: f2(f1(x))`, since any of those might save memory and would express different memory/recompute tradeoffs."
]
},
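{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the last bullet, using three hypothetical scalar stages `stage1`, `stage2` and `stage3` (illustrative choices, not from the text above). Checkpointing the composition of the first two stages should leave the gradient unchanged. The residual printout lets you compare what gets stored with and without the decorator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.ad_checkpoint import print_saved_residuals\n",
"\n",
"# Hypothetical stages of a chain composition stage3(stage2(stage1(x))).\n",
"def stage1(x):\n",
"  return jnp.sin(x)\n",
"\n",
"def stage2(x):\n",
"  return jnp.cos(x)\n",
"\n",
"def stage3(x):\n",
"  return jnp.tanh(x)\n",
"\n",
"def chain_plain(x):\n",
"  return stage3(stage2(stage1(x)))\n",
"\n",
"# Checkpoint the composition of the first two stages: not the whole chain,\n",
"# and not the last stage.\n",
"stages_12 = jax.checkpoint(lambda x: stage2(stage1(x)))\n",
"\n",
"def chain_ckpt(x):\n",
"  return stage3(stages_12(x))\n",
"\n",
"x_demo = jnp.float32(0.3)\n",
"print(jnp.allclose(jax.grad(chain_plain)(x_demo), jax.grad(chain_ckpt)(x_demo)))\n",
"\n",
"print_saved_residuals(chain_plain, x_demo)\n",
"print('---')\n",
"print_saved_residuals(chain_ckpt, x_demo)"
]
},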
{
"cell_type": "markdown",
"metadata": {
"id": "s9KXvtlkyBfq"
},
"source": [
"#### Custom policies for what's saveable"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i9cPf56JyO_h"
},
"source": [
"As shown so far, using `jax.checkpoint` switches from one extreme to another:\n",
"* without `jax.checkpoint`, JAX's autodiff tends to compute everything possible on the forward pass and store it for the backward pass;\n",
"* with a `jax.checkpoint` decorator, we instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.\n",
"\n",
"To operate between these two extremes, saving some things and not others, we can carefully place `jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.\n",
"\n",
"So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) may be saved as residuals (or must instead be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal.\n",
"\n",
"For example, consider this function to be differentiated:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def loss(params, x, y):\n",
" return jnp.sum((predict(params, x) - y)**2)\n",
"\n",
"def predict(params, x):\n",
" *Ws, Wlast = params\n",
" for W in Ws:\n",
" x = layer(W, x)\n",
" x = jnp.dot(Wlast, x)\n",
" return x\n",
"\n",
"def layer(W, x):\n",
" return jnp.sin(jnp.dot(W, x))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"W1 = W2 = W3 = jnp.ones((4, 4))\n",
"params = [W1, W2, W3]\n",
"x = jnp.ones(4)\n",
"y = jnp.ones(4)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4] from the argument 'x'\n",
"f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)\n"
]
}
],
"source": [
"print_saved_residuals(loss, params, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mep7AReRNHby"
},
"source": [
"Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4] from the argument 'x'\n",
"f32[4] from the argument 'y'\n",
"f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)\n"
]
}
],
"source": [
"loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
"print_saved_residuals(loss_checkpoint, params, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bSS5AgbhOtEO"
},
"source": [
"Notice also that by providing a policy, we didn't need to edit the code defining `loss`, `predict`, or `layer`. That is particularly convenient if we want to experiment with policies in calling code (e.g. a training script) without changing library code (e.g. the neural network library)."
]
},
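{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since a policy only decides what is saved versus recomputed, it shouldn't change the gradient values. A minimal sketch that checks this for a few built-in policies. It re-declares `layer`, `predict`, `loss`, `params`, `x` and `y` with the same definitions and values as above so the cell stands on its own. It compares with `jnp.allclose` rather than exact equality in case recomputation reorders floating-point operations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Same model as above, re-declared so this cell is self-contained.\n",
"def layer(W, x):\n",
"  return jnp.sin(jnp.dot(W, x))\n",
"\n",
"def predict(params, x):\n",
"  *Ws, Wlast = params\n",
"  for W in Ws:\n",
"    x = layer(W, x)\n",
"  return jnp.dot(Wlast, x)\n",
"\n",
"def loss(params, x, y):\n",
"  return jnp.sum((predict(params, x) - y)**2)\n",
"\n",
"params = [jnp.ones((4, 4))] * 3\n",
"x = jnp.ones(4)\n",
"y = jnp.ones(4)\n",
"\n",
"reference = jax.grad(loss)(params, x, y)\n",
"policies = {\n",
"  'everything_saveable': jax.checkpoint_policies.everything_saveable,\n",
"  'nothing_saveable': jax.checkpoint_policies.nothing_saveable,\n",
"  'dots_saveable': jax.checkpoint_policies.dots_saveable,\n",
"  'dots_with_no_batch_dims_saveable': jax.checkpoint_policies.dots_with_no_batch_dims_saveable,\n",
"}\n",
"for name, policy in policies.items():\n",
"  grads = jax.grad(jax.checkpoint(loss, policy=policy))(params, x, y)\n",
"  same = all(bool(jnp.allclose(a, b)) for a, b in zip(grads, reference))\n",
"  print(f'{name}: gradients match = {same}')"
]
},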
{
"cell_type": "markdown",
"metadata": {
"id": "wa8NudsITxNx"
},
"source": [
"Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from jax.ad_checkpoint import checkpoint_name\n",
"\n",
"def predict(params, x):\n",
" *Ws, Wlast = params\n",
" for i, W in enumerate(Ws):\n",
" x = layer(W, x)\n",
" x = checkpoint_name(x, name=f'layer{i}_output')\n",
" x = jnp.dot(Wlast, x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eu88XIW6UXW_"
},
"source": [
"By itself, `checkpoint_name` is just an identity function. But because some policy functions know to look for them, we can use the names to control whether certain values output by `checkpoint_name` are considered saveable:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4] from the argument 'x'\n",
"f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)\n",
"f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)\n",
"f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)\n",
"f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)\n"
]
}
],
"source": [
"print_saved_residuals(loss, params, x, y)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4,4] from the argument 'params'\n",
"f32[4] from the argument 'x'\n",
"f32[4] from the argument 'y'\n"
]
}
],
"source": [
"loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))\n",
"print_saved_residuals(loss_checkpoint2, params, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YYBzKTnT6JkL"
},
"source": [
"Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.\n",
"\n",
"Some of the policies are:\n",
"* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)\n",
"* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)\n",
"* `dots_saveable` or its alias `checkpoint_dots`\n",
"* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`\n",
"* `save_anything_but_these_names` (save any values except for the output of\n",
" `checkpoint_name` with any of the names given)\n",
"* `save_any_names_but_these` (save only named values, i.e. any outputs of\n",
" `checkpoint_name`, except for those with the names given)\n",
"* `save_only_these_names` (save only named values, and only among the names\n",
" given)\n",
"\n",
"Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass."
]
},
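{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of `save_only_these_names`. It re-declares the named `predict` from above, plus `layer`, `loss` and the inputs, so the cell stands alone. Only the value tagged `'layer0_output'` is marked saveable. We'd therefore expect the printout to list roughly the arguments and that named value, with everything else needed on the backward pass recomputed. The gradient itself should match plain `jax.grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.ad_checkpoint import checkpoint_name, print_saved_residuals\n",
"\n",
"def layer(W, x):\n",
"  return jnp.sin(jnp.dot(W, x))\n",
"\n",
"def predict(params, x):\n",
"  *Ws, Wlast = params\n",
"  for i, W in enumerate(Ws):\n",
"    x = layer(W, x)\n",
"    x = checkpoint_name(x, name=f'layer{i}_output')  # identity that tags the value\n",
"  return jnp.dot(Wlast, x)\n",
"\n",
"def loss(params, x, y):\n",
"  return jnp.sum((predict(params, x) - y)**2)\n",
"\n",
"params = [jnp.ones((4, 4))] * 3\n",
"x = jnp.ones(4)\n",
"y = jnp.ones(4)\n",
"\n",
"# Only values tagged 'layer0_output' may be saved; the rest is recomputed as needed.\n",
"policy = jax.checkpoint_policies.save_only_these_names('layer0_output')\n",
"loss_checkpoint3 = jax.checkpoint(loss, policy=policy)\n",
"print_saved_residuals(loss_checkpoint3, params, x, y)\n",
"\n",
"g_ref = jax.grad(loss)(params, x, y)\n",
"g_ckpt = jax.grad(loss_checkpoint3)(params, x, y)\n",
"print(all(bool(jnp.allclose(a, b)) for a, b in zip(g_ref, g_ckpt)))"
]
},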
{
"cell_type": "markdown",
"metadata": {
"id": "lixGsLNwxQo7"
},
"source": [
"#### Advanced: recursive `jax.checkpoint`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QHz3fQHZyT16"
},
"source": [
"By applying `jax.checkpoint` in the right way, we can express many tradeoffs between memory usage and (re)computation. One surprising example is _recursive_ checkpointing, where we apply `jax.checkpoint` to a function which itself calls `jax.checkpoint`-decorated functions, in such a way that memory usage from the chain composition of $D$ functions scales like $\\mathcal{O}(\\log_2 D)$ rather than $\\mathcal{O}(D)$.\n",
"\n",
"As a toy example, consider the chain composition of multiple `jnp.sin` functions:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n"
]
}
],
"source": [
"def chain_compose(funs):\n",
" def f(x):\n",
" for fun in funs:\n",
" x = fun(x)\n",
" return x\n",
" return f\n",
"\n",
"f = chain_compose([jnp.sin] * 8)\n",
"print_saved_residuals(f, 3.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3SFXo4n6YJQG"
},
"source": [
"In general, the number of stored residuals scales linearly with the length of the chain:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n",
"f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)\n"
]
}
],
"source": [
"f = chain_compose([jnp.sin] * 16)\n",
"print_saved_residuals(f, 3.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RcTuFRwZYXm7"
},
"source": [
"But we can apply `jax.checkpoint` recursively to improve the scaling:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def recursive_checkpoint(funs):\n",
" if len(funs) == 1:\n",
" return funs[0]\n",
" elif len(funs) == 2:\n",
" f1, f2 = funs\n",
" return lambda x: f1(f2(x))\n",
" else:\n",
" f1 = recursive_checkpoint(funs[:len(funs)//2])\n",
" f2 = recursive_checkpoint(funs[len(funs)//2:])\n",
" return lambda x: f1(jax.checkpoint(f2)(x))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[] from the argument 'x'\n",
"f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)\n",
"f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)\n",
"f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)\n"
]
}
],
"source": [
"f = recursive_checkpoint([jnp.sin] * 8)\n",
"print_saved_residuals(f, 3.)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f32[] from the argument 'x'\n",
"f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)\n",
"f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)\n",
"f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)\n",
"f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)\n"
]
}
],
"source": [
"f = recursive_checkpoint([jnp.sin] * 16)\n",
"print_saved_residuals(f, 3.)"
]
},
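{
"cell_type": "markdown",
"metadata": {},
"source": [
"Recursive checkpointing changes what is stored, not what is computed, so its gradient should match the plain chain's. A minimal sketch of that check, re-declaring `chain_compose` and `recursive_checkpoint` as above. As a third reference it uses a hypothetical helper `manual_grad` that applies the chain rule by hand: the derivative of `sin` applied $D$ times is the product of `cos` evaluated at each intermediate value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def chain_compose(funs):\n",
"  def f(x):\n",
"    for fun in funs:\n",
"      x = fun(x)\n",
"    return x\n",
"  return f\n",
"\n",
"def recursive_checkpoint(funs):\n",
"  if len(funs) == 1:\n",
"    return funs[0]\n",
"  elif len(funs) == 2:\n",
"    f1, f2 = funs\n",
"    return lambda x: f1(f2(x))\n",
"  else:\n",
"    f1 = recursive_checkpoint(funs[:len(funs)//2])\n",
"    f2 = recursive_checkpoint(funs[len(funs)//2:])\n",
"    return lambda x: f1(jax.checkpoint(f2)(x))\n",
"\n",
"def manual_grad(x, depth):\n",
"  # Hypothetical helper: d/dx of sin applied `depth` times is the product\n",
"  # of cos evaluated at x, sin(x), sin(sin(x)), ...\n",
"  total = jnp.ones_like(x)\n",
"  for _ in range(depth):\n",
"    total = total * jnp.cos(x)\n",
"    x = jnp.sin(x)\n",
"  return total\n",
"\n",
"x0 = jnp.float32(3.)\n",
"for depth in (8, 16):\n",
"  plain = jax.grad(chain_compose([jnp.sin] * depth))(x0)\n",
"  recursive = jax.grad(recursive_checkpoint([jnp.sin] * depth))(x0)\n",
"  # A loose tolerance, since the products may be accumulated in a different order.\n",
"  print(depth,\n",
"        jnp.allclose(plain, recursive, rtol=1e-4),\n",
"        jnp.allclose(plain, manual_grad(x0, depth), rtol=1e-4))"
]
},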
{
"cell_type": "markdown",
"metadata": {
"id": "D0yhcndHX0yN"
},
"source": [
"The cost here, as usual, is recomputation: in particular, we end up performing $\\mathcal{O}(\\log_2 D)$ times as many FLOPs:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"> \n",
" <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">forward computation:</span> <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">backward computation:</span> \n",
" \n",
" { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. 
</span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin a</span> <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul i a</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos a</span> <span style=\"color: #000000; text-decoration-color: #000000\"> k</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul j b</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin b</span> <span style=\"color: #000000; text-decoration-color: #000000\"> l</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul k c</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos b</span> <span style=\"color: #000000; text-decoration-color: #000000\"> m</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul l d</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin d</span> <span style=\"color: #000000; text-decoration-color: #000000\"> n</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul m e</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos d</span> <span style=\"color: #000000; text-decoration-color: #000000\"> o</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul n f</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin f</span> <span style=\"color: #000000; text-decoration-color: #000000\"> p</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul o g</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos f</span> <span style=\"color: #000000; text-decoration-color: #000000\"> q</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul p h</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin h</span> <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(q,) }</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> k</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos h</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> l</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin j</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> m</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos j</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> n</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin l</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> o</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos l</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> p</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin n</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> q</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos n</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(p, q, o, m, k, i, g, e, c) }</span> \n",
"</pre>\n"
],
"text/plain": [
"<rich.jupyter.JupyterRenderable at 0x7f5e39c60490>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"f = chain_compose([jnp.sin] * 8)\n",
"print_fwd_bwd(f, 3.)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"> \n",
" <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">forward computation:</span> <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">backward computation:</span> \n",
" \n",
" { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> { <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; a</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">b</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = remat2[</span> <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul d a</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> differentiated=False</span> <span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul e b</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> jaxpr={ </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; c</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span><span style=\"color: #000000; text-decoration-color: #000000\"> d</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin c; e</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin d </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(e,) }</span> <span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = remat2[</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> policy=None</span> <span style=\"color: #000000; text-decoration-color: #000000\"> differentiated=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> prevent_cse=True</span> <span style=\"color: #000000; text-decoration-color: #000000\"> jaxpr={ </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> ] a</span> <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin h</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> f</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin b</span> <span style=\"color: #000000; text-decoration-color: #000000\"> k</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos h</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> g</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin f</span> <span style=\"color: #000000; text-decoration-color: #000000\"> l</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos j</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> h</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin g</span> <span style=\"color: #000000; text-decoration-color: #000000\"> m</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul i l</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> i</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin h</span> <span style=\"color: #000000; text-decoration-color: #000000\"> n</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul m k</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> j</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin i</span> <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(n,) }</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> k</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos i</span> <span style=\"color: #000000; text-decoration-color: #000000\"> policy=None</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> l</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin j</span> <span style=\"color: #000000; text-decoration-color: #000000\"> prevent_cse=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> m</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos j</span> <span style=\"color: #000000; text-decoration-color: #000000\"> ] c f</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(l, m, k, g, a) }</span> <span style=\"color: #000000; text-decoration-color: #000000\"> o</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = remat2[</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> differentiated=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> jaxpr={ </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; p</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> q</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">r</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin p</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> s</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin r</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> t</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin s</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> u</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos s</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> v</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos t</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> w</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul q v</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> x</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul w u</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> y</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = remat2[</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> differentiated=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> jaxpr={ </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">lambda </span><span style=\"color: #000000; text-decoration-color: #000000\">; z</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> ba</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\">. </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">let</span> \n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\"> </span><span style=\"color: #000000; text-decoration-color: #000000\">bb</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = sin z</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bc</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos z</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bd</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = cos bb</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> be</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul ba bd</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> bf</span><span style=\"color: #800080; text-decoration-color: #800080\">:f32[]</span><span style=\"color: #000000; text-decoration-color: #000000\"> = mul be bc</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(bf,) }</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> policy=None</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> prevent_cse=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> ] p x</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(y,) }</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> policy=None</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> prevent_cse=True</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> ] 3.0 g</span> \n",
" <span style=\"color: #000000; text-decoration-color: #000000\"> </span><span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">in </span><span style=\"color: #000000; text-decoration-color: #000000\">(o,) }</span> \n",
"</pre>\n"
],
"text/plain": [
"<rich.jupyter.JupyterRenderable at 0x7f5e39c0ed00>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"f = recursive_checkpoint([jnp.sin] * 8)\n",
"print_fwd_bwd(f, 3.)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvSm1yu0yUtd"
},
"source": [
"### Practical notes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LSADkBOCyX9X"
},
"source": [
"When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.\n",
"\n",
"One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass `scan` and the corresponding backward-pass `scan`) typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n",
"\n",
"For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers to reduce compilation times. The traced program then contains a single copy of the layer body, rather than one copy per layer as an unrolled Python loop would produce. That is, using a simple fully-connected network as an analogy, instead of writing something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BUeqKFRS5yPU"
},
"source": [
"```python\n",
"from typing import Tuple, List\n",
"\n",
"LayerParam = Tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer\n",
"ParamsList = List[LayerParam]\n",
"\n",
"def net(params: ParamsList, x: jnp.ndarray):\n",
" for W, b in params:\n",
" x = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",
" return x\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "38xDcRwW518P"
},
"source": [
"We would instead iterate over the layer application with `jax.lax.scan`:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZU2fwYoG6A4z"
},
"source": [
"```python\n",
"StackedWeights = jnp.ndarray # all weight matrices stacked together\n",
"StackedBiases = jnp.ndarray # all bias vectors stacked together\n",
"\n",
"all_weights = jnp.stack([W for W, _ in params])\n",
"all_biases = jnp.stack([b for _, b in params])\n",
"\n",
"def layer(x, W_b_pair):\n",
" W, b = W_b_pair\n",
" out = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",
" return out, None\n",
"\n",
"def net(all_weights: StackedWeights, all_biases: StackedBiases, x: jnp.ndarray):\n",
" x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))\n",
" return x\n",
"```"
]
},
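{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The two formulations should compute the same function. Below is a minimal, self-contained sketch that checks this on outputs and input gradients. The layer count, width, random initialization, and the names `net_loop` and `net_scan` are assumptions made only for illustration. The sketch also makes one constraint of the scan version explicit: every layer's `W` and `b` must have the same shape, because `jnp.stack` requires identical shapes and `jax.lax.scan` iterates over the stacked leading axis.\n",
  "\n",
  "```python\n",
  "from typing import List, Tuple\n",
  "\n",
  "import jax\n",
  "import jax.numpy as jnp\n",
  "\n",
  "# Hypothetical sizes; scan requires every layer to share these shapes.\n",
  "num_layers, width = 3, 4\n",
  "keys = jax.random.split(jax.random.PRNGKey(0), num_layers)\n",
  "params: List[Tuple[jnp.ndarray, jnp.ndarray]] = [\n",
  "  (jax.random.normal(k, (width, width)) / jnp.sqrt(width), jnp.zeros(width))\n",
  "  for k in keys\n",
  "]\n",
  "\n",
  "def net_loop(params, x):\n",
  "  # Python loop: unrolled at trace time into one copy of the layer per iteration.\n",
  "  for W, b in params:\n",
  "    x = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",
  "  return x\n",
  "\n",
  "def layer(x, W_b_pair):\n",
  "  W, b = W_b_pair\n",
  "  out = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",
  "  return out, None\n",
  "\n",
  "def net_scan(all_weights, all_biases, x):\n",
  "  # scan: the layer body is traced once and applied along the leading (layer) axis.\n",
  "  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))\n",
  "  return x\n",
  "\n",
  "all_weights = jnp.stack([W for W, _ in params])  # (num_layers, width, width)\n",
  "all_biases = jnp.stack([b for _, b in params])   # (num_layers, width)\n",
  "\n",
  "x = jax.random.normal(jax.random.PRNGKey(1), (width,))\n",
  "out_loop = net_loop(params, x)\n",
  "out_scan = net_scan(all_weights, all_biases, x)\n",
  "\n",
  "# Gradients with respect to the input should agree as well.\n",
  "g_loop = jax.grad(lambda x: jnp.sum(net_loop(params, x)))(x)\n",
  "g_scan = jax.grad(lambda x: jnp.sum(net_scan(all_weights, all_biases, x)))(x)\n",
  "print(jnp.allclose(out_loop, out_scan, atol=1e-5), jnp.allclose(g_loop, g_scan, atol=1e-5))\n",
  "```"
 ]
},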
{
"cell_type": "markdown",
"metadata": {
"id": "A-NbCn8A6TFK"
},
"source": [
"This scan-over-layers version reduces compile times, but because it limits the optimizations the compiler can apply across the forward-pass and backward-pass `scan`s, it can lead to inefficient computation of gradients. To mitigate the issue, we can use `jax.checkpoint` on the scanned function:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iHVVNVdO66Dv"
},
"source": [
"```python\n",
"from functools import partial\n",
"\n",
"@partial(jax.checkpoint,\n",
" policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
"def layer(x, W_b_pair):\n",
" W, b = W_b_pair\n",
" out = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",
" return out, None\n",
"```"
]
},
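{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To see the checkpointed layer in context, here is a minimal end-to-end sketch. It assumes arbitrary sizes and a random initialization chosen only for illustration. The names `layer_remat` and `make_loss` are hypothetical, introduced here only for the comparison. The sketch scans both `layer` and a checkpointed copy of it, differentiates a simple sum-of-squares loss under `jax.jit`, and checks that the two variants give the same gradients up to floating-point tolerance. Checkpointing changes which values are saved versus recomputed, not what is computed.\n",
  "\n",
  "```python\n",
  "import jax\n",
  "import jax.numpy as jnp\n",
  "\n",
  "# Hypothetical sizes, chosen only for illustration.\n",
  "num_layers, width = 3, 4\n",
  "k_w, k_x = jax.random.split(jax.random.PRNGKey(0))\n",
  "all_weights = jax.random.normal(k_w, (num_layers, width, width)) / jnp.sqrt(width)\n",
  "all_biases = jnp.zeros((num_layers, width))\n",
  "x = jax.random.normal(k_x, (width,))\n",
  "\n",
  "def layer(x, W_b_pair):\n",
  "  W, b = W_b_pair\n",
  "  out = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",
  "  return out, None\n",
  "\n",
  "# Checkpointed copy: only dots with no batch dimensions (here, x @ W) may be\n",
  "# saved as residuals; other intermediates are recomputed on the backward pass.\n",
  "layer_remat = jax.checkpoint(\n",
  "  layer, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
  "\n",
  "def make_loss(layer_fn):\n",
  "  # Scan layer_fn over the stacked parameters and reduce to a scalar loss.\n",
  "  def loss(all_weights, all_biases, x):\n",
  "    out, _ = jax.lax.scan(layer_fn, x, (all_weights, all_biases))\n",
  "    return jnp.sum(out ** 2)\n",
  "  return loss\n",
  "\n",
  "grad_plain = jax.jit(jax.grad(make_loss(layer), argnums=(0, 1)))\n",
  "grad_remat = jax.jit(jax.grad(make_loss(layer_remat), argnums=(0, 1)))\n",
  "\n",
  "gW_plain, gb_plain = grad_plain(all_weights, all_biases, x)\n",
  "gW_remat, gb_remat = grad_remat(all_weights, all_biases, x)\n",
  "print(jnp.allclose(gW_plain, gW_remat, atol=1e-5),\n",
  "      jnp.allclose(gb_plain, gb_remat, atol=1e-5))\n",
  "```"
 ]
},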
{
"cell_type": "markdown",
"metadata": {
"id": "QYZlp_-s7D4M"
},
"source": [
"By using `jax.checkpoint` this way, we're manually controlling which values JAX's autodiff saves between the forward and backward passes, and hence not relying on XLA optimizations to choose for us."
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
"kind": "private"
},
"provenance": [],
"toc_visible": true
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}