fix error in autodiff cookbook: 3x not 2x

Matthew Johnson 2019-12-30 07:36:36 -08:00 committed by GitHub
parent 30bede1f6a
commit f5723848d3

@@ -774,7 +774,7 @@
"source": [
"The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a \"JVP rule\" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values.\n",
"\n",
"That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 2x the cost of just evaluating the function. Put another way, for a fixed primal point $x$, we can evaluate $v \\mapsto \\partial f(x) \\cdot v$ for about the same cost as evaluating $f$.\n",
"That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example `sin(x)`; one unit for linearizing, like `cos(x)`; and one unit for applying the linearized function to a vector, like `cos_x * v`). Put another way, for a fixed primal point $x$, we can evaluate $v \\mapsto \\partial f(x) \\cdot v$ for about the same marginal cost as evaluating $f$.\n",
"\n",
"That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning?\n",
"\n",
@@ -860,7 +860,7 @@
"\n",
"where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.\n",
"\n",
"This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \\mapsto (f(x), v^\\mathsf{T} \\partial f(x))$ is only about twice the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \\mathbb{R}^n \\to \\mathbb{R}$, we can do it in just one call. That's how `grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.\n",
"This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \\mapsto (f(x), v^\\mathsf{T} \\partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \\mathbb{R}^n \\to \\mathbb{R}$, we can do it in just one call. That's how `grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.\n",
"\n",
"There's a cost, though: though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).\n",
"\n",