fixes from reviewers

Matthew Johnson 2023-02-22 09:26:43 -08:00
parent 141996ec11
commit c22da81d5d
4 changed files with 66 additions and 46 deletions

View File

@ -26,6 +26,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu
notebooks/autodiff_cookbook
notebooks/Custom_derivative_rules_for_Python_code
notebooks/autodiff_remat
.. toctree::
:caption: JAX Internals
@ -39,4 +40,4 @@ This section contains examples and tutorials on more advanced topics, such as Mu
:caption: Deep Dives
:maxdepth: 1
notebooks/convolutions
notebooks/convolutions

View File

@ -29,9 +29,9 @@
"\n",
"Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n",
"\n",
"**Don't miss the [Practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**\n",
"**Don't miss the [practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**\n",
"\n",
"Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves the values of some intermediates, and computes and saves the values of Jacobian coefficients, for use on the backward pass. We call these saved values _residuals_:"
"Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values _residuals_:"
]
},
{
@ -83,7 +83,7 @@
"id": "97vvWfI-fSSF"
},
"source": [
"By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function can be saved, and any residuals are computed on the backward pass from those as needed:"
"By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:"
]
},
{
@ -116,11 +116,22 @@
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here the values of two `sin` applications are saved because they are arguments\n",
"in subsequent applications of the `jax.checkpoint`-decorated `g` function, and\n",
"inputs to a `jax.checkpoint`-decorated function may be saved. But no values of\n",
"`cos` applications are saved."
]
},
{
"cell_type": "markdown",
"id": "807b4764",
"metadata": {
"id": "CyRR3mTpjRtl"
},
"source": [
"To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_, for example to save only results whose computation is likely FLOP-bound:"
"To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_. Here is an example that saves only the results of `dot` operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):"
]
},
{
@ -143,7 +154,7 @@
}
],
"source": [
"f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)\n",
"f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
"jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)"
]
},
@ -323,7 +334,7 @@
}
],
"source": [
"# using jax.checkpoint with policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims:\n",
"# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:\n",
"print_fwd_bwd(f3, W1, W2, W3, x)"
]
},
@ -451,7 +462,7 @@
"source": [
"In words, this alternative implementation doesn't compute `g_vjp`, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass `f_bwd2`. That means `f_vjp_checkpoint` requires less memory: if `g` and `h` each required similar amounts of memory for their residuals, each much larger than `x`, then the function produced by `f_vjp_checkpoint(x)` requires half the memory as that of `f_vjp(x)`!\n",
"\n",
"The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`)."
"The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`)."
]
},
{
@ -460,7 +471,7 @@
"id": "LqTrjPoGqrK7"
},
"source": [
"We can get this autodiff behavior without having to write VJP functions directly by instead using `jax.checkpoint` in an alternative definition of `f`:"
"We can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using `jax.checkpoint` in an alternative definition of the original function `f`:"
]
},
{
@ -512,7 +523,7 @@
"id": "tfssPyb2sQgw"
},
"source": [
"In general, `jax.checkpoint(foo)` is a new function which has the same input-ouptut behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed."
"In general, `jax.checkpoint(foo)` is a new function which has the same input-output behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed."
]
},
{
@ -521,7 +532,7 @@
"id": "HVvS_S5zsVZ-"
},
"source": [
"Notice that with `f = lambda x: h(g(x))` is the funciton we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:\n",
"Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:\n",
"1. run the forward pass, discarding all residuals;\n",
"2. immediately re-run the forward pass, saving residuals;\n",
"3. run the backward pass, consuming residuals from step 2.\n",
@ -606,7 +617,7 @@
"\n",
"To operate between these two extremes, saving some things and not others, we can carefully place `jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.\n",
"\n",
"So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`, since the API for writing custom policy callables is considered internal.\n",
"So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal.\n",
"\n",
"For example, consider this function to be differentiated:"
]
@ -674,7 +685,7 @@
"id": "Mep7AReRNHby"
},
"source": [
"Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`:"
"Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:"
]
},
{
@ -698,7 +709,7 @@
}
],
"source": [
"loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)\n",
"loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
"print_saved_residuals(loss_checkpoint, params, x, y)"
]
},
@ -804,8 +815,8 @@
"Some of the policies are:\n",
"* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)\n",
"* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)\n",
"* `checkpoint_dots`\n",
"* `checkpoint_dots_with_no_batch_dims`\n",
"* `dots_saveable` or its alias `checkpoint_dots`\n",
"* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`\n",
"* `save_anything_but_these_names` (save any values except for the output of\n",
" `checkpoint_name` with any of the names given)\n",
"* `save_any_names_but_these` (save only named values, i.e. any outputs of\n",
@ -1118,9 +1129,9 @@
"source": [
"When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.\n",
"\n",
"One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations between multiple control flow primitives, e.g. between a forward-pass `scan` and the corresponding backward-pass `scan`, aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function we pass to `jax.lax.scan`.\n",
"One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n",
"\n",
"For example, one common pattern in large Transformer models is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:"
"For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:"
]
},
{
@ -1148,7 +1159,7 @@
"id": "38xDcRwW518P"
},
"source": [
"We would instead use `jax.lax.scan`:"
"We would instead iterate over the layer application with `jax.lax.scan`:"
]
},
{
@ -1181,7 +1192,7 @@
"id": "A-NbCn8A6TFK"
},
"source": [
"This scan-over-layers version reduces compile times, but by foiling some XLA optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:"
"This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:"
]
},
{
@ -1194,7 +1205,7 @@
"from functools import partial\n",
"\n",
"@partial(jax.checkpoint,\n",
" policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)\n",
" policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)\n",
"def layer(x, W_b_pair):\n",
" W, b = W_b_pair\n",
" out = jnp.maximum(jnp.dot(x, W) + b, 0.)\n",

View File

@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.1
jupytext_version: 1.14.4
kernelspec:
display_name: Python 3
name: python3
@ -26,9 +26,9 @@ import jax.numpy as jnp
Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.
**Don't miss the [Practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**
**Don't miss the [practical notes](#practical-notes) for a discussion about how `jax.checkpoint` interacts with `jax.jit`.**
Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves the values of some intermediates, and computes and saves the values of Jacobian coefficients, for use on the backward pass. We call these saved values _residuals_:
Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values _residuals_:
```{code-cell}
def g(W, x):
@ -54,7 +54,7 @@ jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
+++ {"id": "97vvWfI-fSSF"}
By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function can be saved, and any residuals are computed on the backward pass from those as needed:
By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:
```{code-cell}
def f2(W1, W2, W3, x):
@ -66,12 +66,17 @@ def f2(W1, W2, W3, x):
jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
```
Here the values of two `sin` applications are saved because they are arguments
in subsequent applications of the `jax.checkpoint`-decorated `g` function, and
inputs to a `jax.checkpoint`-decorated function may be saved. But no values of
`cos` applications are saved.
+++ {"id": "CyRR3mTpjRtl"}
To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_, for example to save only results whose computation is likely FLOP-bound:
To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization _policy_. Here is an example that saves only the results of `dot` operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):
```{code-cell}
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
```
@ -141,7 +146,7 @@ print_fwd_bwd(f, W1, W2, W3, x)
```
```{code-cell}
# using jax.checkpoint with policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims:
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
```
@ -220,11 +225,11 @@ def f_vjp_checkpoint(x):
In words, this alternative implementation doesn't compute `g_vjp`, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass `f_bwd2`. That means `f_vjp_checkpoint` requires less memory: if `g` and `h` each required similar amounts of memory for their residuals, each much larger than `x`, then the function produced by `f_vjp_checkpoint(x)` requires half the memory of `f_vjp(x)`!
The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`).
The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`).
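
As a rough sketch of the pattern described here (not the notebook's exact code; `g` and `h` below are stand-ins), the forward pass builds only `h`'s VJP, while `g`'s VJP, and hence its residuals, are rebuilt inside `f_bwd2`:

```{code-cell}
import jax
import jax.numpy as jnp

g = jnp.sin   # stand-ins for the g and h used earlier
h = jnp.cos

def f_vjp_checkpoint(x):
  z, h_vjp = jax.vjp(h, g(x))   # forward pass: keep h's residuals, none of g's
  def f_bwd2(z_bar):
    _, g_vjp = jax.vjp(g, x)    # backward pass: re-run g just to rebuild its VJP
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

z, f_bwd2 = f_vjp_checkpoint(2.)
print(f_bwd2(1.))   # matches jax.grad(lambda x: h(g(x)))(2.)
```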
+++ {"id": "LqTrjPoGqrK7"}
We can get this autodiff behavior without having to write VJP functions directly by instead using `jax.checkpoint` in an alternative definition of `f`:
We can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using `jax.checkpoint` in an alternative definition of the original function `f`:
```{code-cell}
def f_checkpoint(x):
@ -256,11 +261,11 @@ def f_checkpoint_grad(x):
+++ {"id": "tfssPyb2sQgw"}
In general, `jax.checkpoint(foo)` is a new function which has the same input-ouptut behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed.
In general, `jax.checkpoint(foo)` is a new function which has the same input-output behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-differentiated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed.
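
To make the first part of that statement concrete, here is a quick check (a sketch, with `jnp.sin` standing in for an arbitrary `foo`): the outputs and gradient values agree, and only the save-versus-recompute behavior differs.

```{code-cell}
import jax
import jax.numpy as jnp

foo = jnp.sin                    # stand-in for an arbitrary differentiable function
foo_remat = jax.checkpoint(foo)

x = 1.5
print(foo(x), foo_remat(x))                      # identical outputs
print(jax.grad(foo)(x), jax.grad(foo_remat)(x))  # identical gradient values
```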
+++ {"id": "HVvS_S5zsVZ-"}
Notice that with `f = lambda x: h(g(x))` is the funciton we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:
Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don't get any memory savings by applying `jax.checkpoint` to `f` itself. That's because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:
1. run the forward pass, discarding all residuals;
2. immediately re-run the forward pass, saving residuals;
3. run the backward pass, consuming residuals from step 2.
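
A small sketch of the contrast (with `g` and `h` as stand-in sub-functions): checkpointing all of `f` stores only its input, so nothing is gained relative to plain recomputation, whereas checkpointing a sub-function still lets the rest of `f` save its residuals.

```{code-cell}
import jax
import jax.numpy as jnp

g, h = jnp.sin, jnp.cos        # stand-ins for the sub-functions above
f = lambda x: h(g(x))

# Checkpointing all of f: only the argument is saved, and the whole forward
# pass is re-run during the backward pass.
jax.ad_checkpoint.print_saved_residuals(jax.checkpoint(f), 3.)

# Checkpointing only g: h's residuals are still saved; only g's are recomputed.
f_mixed = lambda x: h(jax.checkpoint(g)(x))
jax.ad_checkpoint.print_saved_residuals(f_mixed, 3.)
```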
@ -315,7 +320,7 @@ As shown so far, using `jax.checkpoint` switches from one extreme to another:
To operate between these two extremes, saving some things and not others, we can carefully place `jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.
So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`, since the API for writing custom policy callables is considered internal.
So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal.
For example, consider this function to be differentiated:
@ -347,10 +352,10 @@ print_saved_residuals(loss, params, x, y)
+++ {"id": "Mep7AReRNHby"}
Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims`:
Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:
```{code-cell}
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
```
@ -394,8 +399,8 @@ Another policy which refers to names is `jax.checkpoint_policies.save_only_these
Some of the policies are:
* `everything_saveable` (the default strategy, as if `jax.checkpoint` were not being used at all)
* `nothing_saveable` (i.e. rematerialize everything, as if a custom policy were not being used at all)
* `checkpoint_dots`
* `checkpoint_dots_with_no_batch_dims`
* `dots_saveable` or its alias `checkpoint_dots`
* `dots_with_no_batch_dims_saveable` or its alias `checkpoint_dots_with_no_batch_dims`
* `save_anything_except_these_names` (save any values except for the output of
`checkpoint_name` with any of the names given)
* `save_any_names_but_these` (save only named values, i.e. any outputs of
@ -485,9 +490,9 @@ print_fwd_bwd(f, 3.)
When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.
One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations between multiple control flow primitives, e.g. between a forward-pass `scan` and the corresponding backward-pass `scan`, aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function we pass to `jax.lax.scan`.
One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.
For example, one common pattern in large Transformer models is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
+++ {"id": "BUeqKFRS5yPU"}
@ -505,7 +510,7 @@ def net(params: ParamsList, x: jnp.ndarray):
+++ {"id": "38xDcRwW518P"}
We would instead use `jax.lax.scan`:
We would instead iterate over the layer application with `jax.lax.scan`:
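
A minimal sketch of that scan-over-layers structure (the stacked-parameter shapes and names here are illustrative assumptions, not the notebook's exact code):

```{code-cell}
import jax
import jax.numpy as jnp

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)   # one fully connected layer with ReLU
  return out, None

def net(all_weights, all_biases, x):
  # all_weights: [num_layers, d, d], all_biases: [num_layers, d];
  # scan applies `layer` once per leading-axis slice.
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

d, num_layers = 4, 3
all_weights = jax.random.normal(jax.random.PRNGKey(0), (num_layers, d, d))
all_biases = jnp.zeros((num_layers, d))
print(net(all_weights, all_biases, jnp.ones(d)))
```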
+++ {"id": "ZU2fwYoG6A4z"}
@ -528,7 +533,7 @@ def net(all_weights, all_biases, x):
+++ {"id": "A-NbCn8A6TFK"}
This scan-over-layers version reduces compile times, but by foiling some XLA optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:
This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:
+++ {"id": "iHVVNVdO66Dv"}
@ -536,7 +541,7 @@ This scan-over-layers version reduces compile times, but by foiling some XLA opt
from functools import partial
@partial(jax.checkpoint,
policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
W, b = W_b_pair
out = jnp.maximum(jnp.dot(x, W) + b, 0.)

View File

@ -57,12 +57,13 @@ def nothing_saveable(*_, **__) -> bool:
# This is the effective policy when using jax.remat without explicit policy.
return False
def checkpoint_dots(prim, *_, **__) -> bool:
def dots_saveable(prim, *_, **__) -> bool:
# Matrix multiplies are expensive, so let's save them (and nothing else).
return prim in {lax_internal.dot_general_p,
lax_convolution.conv_general_dilated_p}
checkpoint_dots = dots_saveable
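# The assignment above keeps the previous public name `checkpoint_dots` working as an alias.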
def dot_with_no_batch_dims(prim, *_, **params) -> bool:
def dot_with_no_batch_dims_saveable(prim, *_, **params) -> bool:
# This is a useful heuristic for transformers.
if prim is lax_internal.dot_general_p:
(_, _), (lhs_b, rhs_b) = params['dimension_numbers']
@ -111,8 +112,10 @@ def save_from_both_policies(policy_1, policy_2):
checkpoint_policies = types.SimpleNamespace(
everything_saveable=everything_saveable,
nothing_saveable=nothing_saveable,
checkpoint_dots=checkpoint_dots,
checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims,
dots_saveable=dots_saveable,
checkpoint_dots=dots_saveable,
dots_with_no_batch_dims_saveable=dot_with_no_batch_dims_saveable,
checkpoint_dots_with_no_batch_dims=dot_with_no_batch_dims_saveable,
save_anything_except_these_names=save_anything_except_these_names,
save_any_names_but_these=save_any_names_but_these,
save_only_these_names=save_only_these_names,