docs: prefer "summary" to "tl;dr"

This was common and we typically just mean "summary."
This commit is contained in:
Roy Frostig 2024-08-28 10:46:47 -07:00
parent 46957052c5
commit 672a013b3a
10 changed files with 17 additions and 16 deletions

View File

@ -23,7 +23,7 @@ Let's begin with {func}`jax.debug.print`.
## JAX `debug.print` for high-level
**TL;DR** Here is a rule of thumb:
Here is a rule of thumb:
- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
- Use Python {func}`print` for static values, such as dtypes and array shapes.
@ -113,7 +113,7 @@ To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`a
## JAX `debug.breakpoint` for `pdb`-like debugging
**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.
**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.
To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack.

View File

@ -2,7 +2,7 @@
<!--* freshness: { reviewed: '2023-02-28' } *-->
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python
from jax.experimental import checkify

View File

@ -6,7 +6,7 @@ JAX offers flags and context managers that enable catching errors more easily.
## `jax_debug_nans` configuration option and context manager
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).
`jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN.
@ -41,7 +41,7 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
## `jax_disable_jit` configuration option and context manager
**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`
**Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`
`jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`).

View File

@ -2,7 +2,7 @@
<!--* freshness: { reviewed: '2024-04-11' } *-->
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more.
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has summaries and you can click the "Read more" links at the bottom to learn more.
Table of contents:
@ -12,7 +12,7 @@ Table of contents:
## [Interactive inspection with `jax.debug`](print_breakpoint)
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
**Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
```python
@ -38,7 +38,7 @@ Click [here](print_breakpoint) to learn more!
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python
from jax.experimental import checkify
@ -81,7 +81,7 @@ Click [here](checkify_guide) to learn more!
## [Throwing Python errors with JAX's debug flags](flags)
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
```python
import jax

View File

@ -7,7 +7,8 @@ inside of JIT-ted functions.
## Debugging with `jax.debug.print` and other debugging callbacks
**TL;DR** Use {func}`jax.debug.print` to print traced array values to stdout in `jit`- and `pmap`-decorated functions:
**Summary:** Use {func}`jax.debug.print` to print traced array values to
stdout in compiled (e.g. `jax.jit` or `jax.pmap`-decorated) functions:
```python
import jax
@ -236,7 +237,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat
## Interactive inspection with `jax.debug.breakpoint()`
**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
**Summary:** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
```python
@jax.jit

View File

@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and
cross-platform, and `jaxlib` which contains compiled binaries, and requires
different builds for different operating systems and accelerators.
**TL;DR** For most users, a typical JAX installation may look something like this:
**Summary:** For most users, a typical JAX installation may look something like this:
* **CPU-only (Linux/macOS/Windows)**
```

View File

@ -30,7 +30,7 @@
"id": "9Fg3NFNY-2RY"
},
"source": [
"## TL;DR"
"## Summary"
]
},
{

View File

@ -32,7 +32,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co
+++ {"id": "9Fg3NFNY-2RY"}
## TL;DR
## Summary
+++ {"id": "ZgMNRtXyWIW8"}

View File

@ -27,7 +27,7 @@
"id": "qaIsQSh1XoKF"
},
"source": [
"### TL;DR\n",
"### Summary\n",
"\n",
"Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n",
"\n",

View File

@ -24,7 +24,7 @@ import jax.numpy as jnp
+++ {"id": "qaIsQSh1XoKF"}
### TL;DR
### Summary
Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.