mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

fixes from reviewers

This commit is contained in:
parent 141996ec11
commit c22da81d5d
@ -26,6 +26,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu
 
    notebooks/autodiff_cookbook
    notebooks/Custom_derivative_rules_for_Python_code
+   notebooks/autodiff_remat
 
 .. toctree::
    :caption: JAX Internals
@ -39,4 +40,4 @@ This section contains examples and tutorials on more advanced topics, such as Mu
    :caption: Deep Dives
    :maxdepth: 1
 
-   notebooks/convolutions
+   notebooks/convolutions
@ -29,9 +29,9 @@
     "\n",
     "Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n",
     "\n",
-    "**Don't miss the [Practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**\n",
+    "**Don't miss the [practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**\n",
     "\n",
-    "Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves the values of some intermediates, and computes and saves the values of Jacobian coefficients, for use on the backward pass. We call these saved values _residuals_:"
+    "Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values _residuals_:"
    ]
   },
   {
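For context, the residuals this hunk discusses can be inspected directly. A minimal runnable sketch, with `g` and `f` matching the notebook's own definitions and illustrative shapes:

```python
import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Lists each intermediate value autodiff would store for the backward pass.
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
```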
@ -83,7 +83,7 @@
    "id": "97vvWfI-fSSF"
   },
   "source": [
-    "By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function can be saved, and any residuals are computed on the backward pass from those as needed:"
+    "By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:"
    ]
   },
   {
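The `f2` referenced by this cell (shown in full in the Markdown hunk below) checkpoints `g` at each application site; reusing the definitions above:

```python
def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

# Only inputs to the checkpointed calls may now be saved as residuals;
# no `cos` values (the Jacobian coefficients of `sin`) are stored.
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
```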
@ -116,11 +116,22 @@
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here the values of two `sin` applications are saved because they are arguments\n",
+    "in subsequent applications of the `jax.checkpoint`-decorated `g` function, and\n",
+    "inputs to a `jax.checkpoint`-decorated function may be saved. But no values of\n",
+    "`cos` applications are saved."
+   ]
+  },
  {
   "cell_type": "markdown",
+   "id": "807b4764",
   "metadata": {
    "id": "CyRR3mTpjRtl"
   },
   "source": [
-    "To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_, for example to save only results whose computation is likely FLOP-bound:"
+    "To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_. Here is an example that saves only the results of `dot` operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):"
   ]
  },
  {
@ -143,7 +154,7 @@
    }
   ],
   "source": [
-    "f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)\n",
+    "f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
    "jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)"
   ]
  },
@ -323,7 +334,7 @@
    }
   ],
   "source": [
-    "# using jax.checkpoint with policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims:\n",
+    "# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:\n",
    "print_fwd_bwd(f3, W1, W2, W3, x)"
   ]
  },
@ -451,7 +462,7 @@
   "source": [
    "In words, this alternative implementation doesn't compute `g_vjp`, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass `f_bwd2`. That means `f_vjp_checkpoint` requires less memory: if `g` and `h` each required similar amounts of memory for their residuals, each much larger than `x`, then the function produced by `f_vjp_checkpoint(x)` requires half the memory as that of `f_vjp(x)`!\n",
    "\n",
-    "The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`)."
+    "The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`)."
   ]
  },
  {
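A compact sketch of the implementation this cell analyzes, assuming scalar-to-scalar `g` and `h` composed as `f = lambda x: h(g(x))`, as in the surrounding notebook:

```python
def f_vjp_checkpoint(x):
  y = g(x)                    # forward: evaluate g, keep no residuals for it
  z, h_vjp = jax.vjp(h, y)    # forward: save only h's residuals
  def f_bwd2(z_bar):
    _, g_vjp = jax.vjp(g, x)  # backward: re-run g just to rebuild its VJP
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2
```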
@ -460,7 +471,7 @@
   "id": "LqTrjPoGqrK7"
  },
   "source": [
-    "We can get this autodiff behavior without having to write VJP functions directly by instead using `jax.checkpoint` in an alternative definition of `f`:"
+    "We can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using `jax.checkpoint` in an alternative definition of the original function `f`:"
   ]
  },
  {
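The alternative definition this cell refers to (it appears in full in the Markdown hunk below):

```python
def f_checkpoint(x):
  y = jax.checkpoint(g)(x)  # under grad, save x instead of g's residuals
  z = h(y)                  # h's residuals are saved as usual
  return z
```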
@ -512,7 +523,7 @@
   "id": "tfssPyb2sQgw"
  },
   "source": [
-    "In general, `jax.checkpoint(foo)` is a new function which has the same input-ouptut behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed."
+    "In general, `jax.checkpoint(foo)` is a new function which has the same input-output behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed."
   ]
  },
  {
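One point worth a quick check: forward-mode `jax.jvp` is unaffected, so the following sketch (with an illustrative `foo`) computes identical results with and without `jax.checkpoint`:

```python
import jax
import jax.numpy as jnp

foo = jnp.sin
p1, t1 = jax.jvp(foo, (1.0,), (1.0,))
p2, t2 = jax.jvp(jax.checkpoint(foo), (1.0,), (1.0,))
assert jnp.allclose(p1, p2) and jnp.allclose(t1, t2)
```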
@ -521,7 +532,7 @@
   "id": "HVvS_S5zsVZ-"
  },
   "source": [
-    "Notice that with `f = lambda x: h(g(x))` is the funciton we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:\n",
+    "Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:\n",
    "1. run the forward pass, discarding all residuals;\n",
    "2. immediately re-run the forward pass, saving residuals;\n",
    "3. run the backward pass, consuming residuals from step 2.\n",
@ -606,7 +617,7 @@
    "\n",
    "To operate between these two extremes, saving some things and not others, we can carefully place `jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.\n",
    "\n",
-    "So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`, since the API for writing custom policy callables is considered internal.\n",
+    "So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal.\n",
    "\n",
    "For example, consider this function to be differentiated:"
   ]
@ -674,7 +685,7 @@
   "id": "Mep7AReRNHby"
  },
   "source": [
-    "Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`:"
+    "Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:"
   ]
  },
  {
@ -698,7 +709,7 @@
    }
   ],
   "source": [
-    "loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)\n",
+    "loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
    "print_saved_residuals(loss_checkpoint, params, x, y)"
   ]
  },
@ -804,8 +815,8 @@
    "Some of the policies are:\n",
    "* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)\n",
    "* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)\n",
-    "* `checkpoint_dots`\n",
-    "* `checkpoint_dots_with_no_batch_dims`\n",
+    "* `dots_saveable` or its alias `checkpoint_dots`\n",
+    "* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`\n",
    "* `save_anything_except_these_names` (save any values except for the output of\n",
    "  `checkpoint_name` with any of the names given)\n",
    "* `save_any_names_but_these` (save only named values, i.e. any outputs of\n",
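Since the named-value policies are only listed in this cell, a hedged sketch of how they pair with `jax.ad_checkpoint.checkpoint_name` (the tag `'act'` is an illustrative name, not from the diff):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def g_named(W, x):
  y = jnp.dot(W, x)
  y = checkpoint_name(y, 'act')  # tag this intermediate with a name
  return jnp.sin(y)

# Save only values tagged 'act'; everything else is rematerialized.
g_remat = jax.checkpoint(
    g_named, policy=jax.checkpoint_policies.save_only_these_names('act'))
```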
@ -1118,9 +1129,9 @@
   "source": [
    "When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.\n",
    "\n",
-    "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations between multiple control flow primitives, e.g. between a forward-pass `scan` and the corresponding backward-pass `scan`, aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function we pass to `jax.lax.scan`.\n",
+    "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n",
    "\n",
-    "For example, one common pattern in large Transformer models is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:"
+    "For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:"
   ]
  },
  {
@ -1148,7 +1159,7 @@
   "id": "38xDcRwW518P"
  },
   "source": [
-    "We would instead use `jax.lax.scan`:"
+    "We would instead iterate over the layer application with `jax.lax.scan`:"
   ]
  },
  {
@ -1181,7 +1192,7 @@
   "id": "A-NbCn8A6TFK"
  },
   "source": [
-    "This scan-over-layers version reduces compile times, but by foiling some XLA optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:"
+    "This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:"
   ]
  },
  {
@ -1194,7 +1205,7 @@
    "from functools import partial\n",
    "\n",
    "@partial(jax.checkpoint,\n",
-    "         policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)\n",
+    "         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
    "def layer(x, W_b_pair):\n",
    "  W, b = W_b_pair\n",
    "  out = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",
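Assembled from the pieces this diff touches, a runnable sketch of the checkpointed scan-over-layers pattern (the `return out, None` and the `net` wrapper follow the notebook's scan version; shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  # Scan the checkpointed layer over stacked per-layer parameters; each
  # layer's residuals are recomputed on the backward pass per the policy.
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

all_weights = jnp.ones((4, 8, 8))  # 4 layers of width 8
all_biases = jnp.ones((4, 8))
y = net(all_weights, all_biases, jnp.ones(8))
```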
@ -5,7 +5,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.1
+    jupytext_version: 1.14.4
 kernelspec:
   display_name: Python 3
   name: python3
@ -26,9 +26,9 @@ import jax.numpy as jnp
 
 Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.
 
-**Don't miss the [Practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**
+**Don't miss the [practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**
 
-Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves the values of some intermediates, and computes and saves the values of Jacobian coefficients, for use on the backward pass. We call these saved values _residuals_:
+Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values _residuals_:
 
 ```{code-cell}
 def g(W, x):
@ -54,7 +54,7 @@ jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
 
 +++ {"id": "97vvWfI-fSSF"}
 
-By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function can be saved, and any residuals are computed on the backward pass from those as needed:
+By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:
 
 ```{code-cell}
 def f2(W1, W2, W3, x):
@ -66,12 +66,17 @@ def f2(W1, W2, W3, x):
 jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
 ```
 
+Here the values of two `sin` applications are saved because they are arguments
+in subsequent applications of the `jax.checkpoint`-decorated `g` function, and
+inputs to a `jax.checkpoint`-decorated function may be saved. But no values of
+`cos` applications are saved.
+
 +++ {"id": "CyRR3mTpjRtl"}
 
-To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_, for example to save only results whose computation is likely FLOP-bound:
+To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_. Here is an example that saves only the results of `dot` operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):
 
 ```{code-cell}
-f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
+f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
 jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
 ```
 
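Policies can also be combined. A hedged sketch using `save_from_both_policies` (defined in the source hunk at the end of this diff; this assumes it is exported on `jax.checkpoint_policies`, and that some values were tagged `'act'` via `checkpoint_name`):

```python
# A value is saveable if either constituent policy allows it.
combined = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
    jax.checkpoint_policies.save_only_these_names('act'))
f4 = jax.checkpoint(f, policy=combined)
```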
@ -141,7 +146,7 @@ print_fwd_bwd(f, W1, W2, W3, x)
 ```
 
 ```{code-cell}
-# using jax.checkpoint with policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims:
+# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
 print_fwd_bwd(f3, W1, W2, W3, x)
 ```
 
@ -220,11 +225,11 @@ def f_vjp_checkpoint(x):
 
 In words, this alternative implementation doesn't compute `g_vjp`, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass `f_bwd2`. That means `f_vjp_checkpoint` requires less memory: if `g` and `h` each required similar amounts of memory for their residuals, each much larger than `x`, then the function produced by `f_vjp_checkpoint(x)` requires half the memory as that of `f_vjp(x)`!
 
-The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`).
+The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`).
 
 +++ {"id": "LqTrjPoGqrK7"}
 
-We can get this autodiff behavior without having to write VJP functions directly by instead using `jax.checkpoint` in an alternative definition of `f`:
+We can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using `jax.checkpoint` in an alternative definition of the original function `f`:
 
 ```{code-cell}
 def f_checkpoint(x):
@ -256,11 +261,11 @@ def f_checkpoint_grad(x):
 
 +++ {"id": "tfssPyb2sQgw"}
 
-In general, `jax.checkpoint(foo)` is a new function which has the same input-ouptut behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed.
+In general, `jax.checkpoint(foo)` is a new function which has the same input-output behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed.
 
 +++ {"id": "HVvS_S5zsVZ-"}
 
-Notice that with `f = lambda x: h(g(x))` is the funciton we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:
+Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:
 1. run the forward pass, discarding all residuals;
 2. immediately re-run the forward pass, saving residuals;
 3. run the backward pass, consuming residuals from step 2.
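To make the contrast in steps 1-3 concrete, a small sketch (assuming scalar-to-scalar `g` and `h` as in the notebook's earlier cells):

```python
f = lambda x: h(g(x))

# No memory savings: the whole forward pass is discarded, then
# immediately re-run with residuals saved (steps 1-3 above).
grad_f_remat_whole = jax.grad(jax.checkpoint(f))

# Savings come from checkpointing sub-functions instead:
f_checkpoint = lambda x: h(jax.checkpoint(g)(x))
grad_f_remat_sub = jax.grad(f_checkpoint)
```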
@ -315,7 +320,7 @@ As shown so far, using `jax.checkpoint` switches from one extreme to another:
 
 To operate between these two extremes, saving some things and not others, we can carefully place `jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.
 
-So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`, since the API for writing custom policy callables is considered internal.
+So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal.
 
 For example, consider this function to be differentiated:
 
@ -347,10 +352,10 @@ print_saved_residuals(loss, params, x, y)
 
 +++ {"id": "Mep7AReRNHby"}
 
-Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`:
+Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:
 
 ```{code-cell}
-loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
+loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
 print_saved_residuals(loss_checkpoint, params, x, y)
 ```
 
@ -394,8 +399,8 @@ Another policy which refers to names is `jax.checkpoint_policies.save_only_these
 Some of the policies are:
 * `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)
 * `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)
-* `checkpoint_dots`
-* `checkpoint_dots_with_no_batch_dims`
+* `dots_saveable` or its alias `checkpoint_dots`
+* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`
 * `save_anything_except_these_names` (save any values except for the output of
   `checkpoint_name` with any of the names given)
 * `save_any_names_but_these` (save only named values, i.e. any outputs of
@ -485,9 +490,9 @@ print_fwd_bwd(f, 3.)
 
 When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.
 
-One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations between multiple control flow primitives, e.g. between a forward-pass `scan` and the corresponding backward-pass `scan`, aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function we pass to `jax.lax.scan`.
+One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.
 
-For example, one common pattern in large Transformer models is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
+For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
 
 +++ {"id": "BUeqKFRS5yPU"}
 
@ -505,7 +510,7 @@ def net(params: ParamsList, x: jnp.ndarray):
 
 +++ {"id": "38xDcRwW518P"}
 
-We would instead use `jax.lax.scan`:
+We would instead iterate over the layer application with `jax.lax.scan`:
 
 +++ {"id": "ZU2fwYoG6A4z"}
 
@ -528,7 +533,7 @@ def net(all_weights, all_biases, x):
 
 +++ {"id": "A-NbCn8A6TFK"}
 
-This scan-over-layers version reduces compile times, but by foiling some XLA optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:
+This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:
 
 +++ {"id": "iHVVNVdO66Dv"}
 
@ -536,7 +541,7 @@ This scan-over-layers version reduces compile times, but by foiling some XLA opt
 from functools import partial
 
 @partial(jax.checkpoint,
-         policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
+         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
 def layer(x, W_b_pair):
   W, b = W_b_pair
   out = jnp.maximum(jnp.dot(x, W) + b, 0.)
@ -57,12 +57,13 @@ def nothing_saveable(*_, **__) -> bool:
   # This is the effective policy when using jax.remat without explicit policy.
   return False
 
-def checkpoint_dots(prim, *_, **__) -> bool:
+def dots_saveable(prim, *_, **__) -> bool:
   # Matrix multiplies are expensive, so let's save them (and nothing else).
   return prim in {lax_internal.dot_general_p,
                   lax_convolution.conv_general_dilated_p}
+checkpoint_dots = dots_saveable
 
-def dot_with_no_batch_dims(prim, *_, **params) -> bool:
+def dot_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
   # This is a useful heuristic for transformers.
   if prim is lax_internal.dot_general_p:
     (_, _), (lhs_b, rhs_b) = params['dimension_numbers']
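The viewer truncates this hunk mid-function. For reference, the heuristic plausibly continues along these lines (a reconstruction from the surrounding context, not part of the diff):

```python
def dot_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
  # Save dot_general results only when neither operand carries batch dims.
  if prim is lax_internal.dot_general_p:
    (_, _), (lhs_b, rhs_b) = params['dimension_numbers']
    if not lhs_b and not rhs_b:  # reconstruction: truncated in this view
      return True
  return False
```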
@ -111,8 +112,10 @@ def save_from_both_policies(policy_1, policy_2):
 checkpoint_policies = types.SimpleNamespace(
     everything_saveable=everything_saveable,
     nothing_saveable=nothing_saveable,
-    checkpoint_dots=checkpoint_dots,
-    checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims,
+    dots_saveable=dots_saveable,
+    checkpoint_dots=dots_saveable,
+    dots_with_no_batch_dims_saveable=dot_with_no_batch_dims_saveable,
+    checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims_saveable,
     save_anything_except_these_names=save_anything_except_these_names,
     save_any_names_but_these=save_any_names_but_these,
     save_only_these_names=save_only_these_names,
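A usage note on the renaming this hunk lands: the old attribute names are kept as aliases of the new ones, so existing user code keeps working. A quick check:

```python
import jax

assert jax.checkpoint_policies.checkpoint_dots is jax.checkpoint_policies.dots_saveable
assert (jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
        is jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
```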