Update the advanced autodiff tutorial and replace some vmap with grad
This commit is contained in:
parent 853af56007
commit f881f507d6
@@ -247,7 +247,7 @@ perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
### Hessian-vector products with `jax.grad`-of-`jax.grad`
-One thing you can do with higher-order {func}`jax.vmap` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
+One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)).
@@ -259,11 +259,11 @@ for any $v \in \mathbb{R}^n$.
The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
-Luckily, {func}`jax.vmap` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
+Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v](x) = \partial g(x)$,
-where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.vmap` is efficient.
+where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient.
In JAX code, you can just write this:
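The hunk cuts off before the code block itself, so here is a minimal sketch consistent with the identity above (assuming `f` is a scalar-valued function of an array argument; the example `f`, `x`, and `v` are illustrative, not taken from the tutorial):

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # g(x) = <grad f(x), v> is scalar-valued, so we can apply jax.grad again.
    return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)

# Quick check: f(x) = sum(x**3) has Hessian diag(6 * x), so H @ v = 6 * x * v.
f = lambda x: jnp.sum(x ** 3)
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.5, -1.0])
print(hvp(f, x, v))  # [  6.   6. -18.]
print(6 * x * v)     # matches
```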
@@ -357,7 +357,7 @@ To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfw
### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)
-JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.vmap` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
+JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
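For context (forward mode is not shown in this hunk), it is exposed directly through {func}`jax.jvp`, which evaluates the function and a directional derivative together; the example function below is illustrative only:

```python
import jax
import jax.numpy as jnp

f = lambda x: x * jnp.sin(x)

x = jnp.array(2.0)   # primal point
v = jnp.array(1.0)   # tangent (direction)

# Forward mode: push the tangent v through the computation alongside f(x).
y, tangent_out = jax.jvp(f, (x,), (v,))
print(y)             # f(2) = 2 * sin(2)
print(tangent_out)   # f'(2) = sin(2) + 2 * cos(2)

# Reverse mode via jax.grad gives the same derivative for this scalar case.
print(jax.grad(f)(x))
```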
#### JVPs in math
@@ -473,7 +473,7 @@ vjp :: (a -> b) -> a -> (b, CT b -> CT a)
where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.
-This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
+This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
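As a sketch of that pair structure and the one-call gradient claim (the example function here is illustrative, not from the tutorial), {func}`jax.vjp` returns the primal output together with the pullback, and for a scalar-valued `f` a single pullback of a unit cotangent recovers the gradient:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2)   # f : R^n -> R

x = jnp.arange(3.0)

# jax.vjp returns the primal output and a linear pullback (CT b -> CT a).
y, f_vjp = jax.vjp(f, x)

# For a scalar output, one pullback of a unit cotangent yields the gradient.
(g,) = f_vjp(jnp.ones_like(y))
print(y)               # 5.0
print(g)               # [0. 2. 4.]
print(jax.grad(f)(x))  # same result
```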
There's a cost, though: although the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).