mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
docs: prefer "summary" to "tl;dr"
This was common and we typically just mean "summary."
This commit is contained in:
parent
46957052c5
commit
672a013b3a
@ -23,7 +23,7 @@ Let's begin with {func}`jax.debug.print`.
|
||||
|
||||
## JAX `debug.print` for high-level
|
||||
|
||||
**TL;DR** Here is a rule of thumb:
|
||||
Here is a rule of thumb:
|
||||
|
||||
- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
|
||||
- Use Python {func}`print` for static values, such as dtypes and array shapes.
|
||||
@ -113,7 +113,7 @@ To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`a
|
||||
|
||||
## JAX `debug.breakpoint` for `pdb`-like debugging
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.
|
||||
**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.
|
||||
|
||||
To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack.
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
<!--* freshness: { reviewed: '2023-02-28' } *-->
|
||||
|
||||
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
|
||||
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
|
||||
|
||||
```python
|
||||
from jax.experimental import checkify
|
||||
|
@ -6,7 +6,7 @@ JAX offers flags and context managers that enable catching errors more easily.
|
||||
|
||||
## `jax_debug_nans` configuration option and context manager
|
||||
|
||||
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).
|
||||
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).
|
||||
|
||||
`jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN.
|
||||
|
||||
@ -41,7 +41,7 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
|
||||
|
||||
## `jax_disable_jit` configuration option and context manager
|
||||
|
||||
**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`
|
||||
**Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`
|
||||
|
||||
`jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`).
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
<!--* freshness: { reviewed: '2024-04-11' } *-->
|
||||
|
||||
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more.
|
||||
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has summaries and you can click the "Read more" links at the bottom to learn more.
|
||||
|
||||
Table of contents:
|
||||
|
||||
@ -12,7 +12,7 @@ Table of contents:
|
||||
|
||||
## [Interactive inspection with `jax.debug`](print_breakpoint)
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
|
||||
**Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
|
||||
and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
|
||||
|
||||
```python
|
||||
@ -38,7 +38,7 @@ Click [here](print_breakpoint) to learn more!
|
||||
|
||||
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
|
||||
|
||||
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
|
||||
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
|
||||
|
||||
```python
|
||||
from jax.experimental import checkify
|
||||
@ -81,7 +81,7 @@ Click [here](checkify_guide) to learn more!
|
||||
|
||||
## [Throwing Python errors with JAX's debug flags](flags)
|
||||
|
||||
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
|
||||
```python
|
||||
import jax
|
||||
|
@ -7,7 +7,8 @@ inside of JIT-ted functions.
|
||||
|
||||
## Debugging with `jax.debug.print` and other debugging callbacks
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.print` to print traced array values to stdout in `jit`- and `pmap`-decorated functions:
|
||||
**Summary:** Use {func}`jax.debug.print` to print traced array values to
|
||||
stdout in compiled (e.g. `jax.jit` or `jax.pmap`-decorated) functions:
|
||||
|
||||
```python
|
||||
import jax
|
||||
@ -236,7 +237,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat
|
||||
|
||||
## Interactive inspection with `jax.debug.breakpoint()`
|
||||
|
||||
**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
|
||||
**Summary:** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
|
||||
|
||||
```python
|
||||
@jax.jit
|
||||
|
@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and
|
||||
cross-platform, and `jaxlib` which contains compiled binaries, and requires
|
||||
different builds for different operating systems and accelerators.
|
||||
|
||||
**TL;DR** For most users, a typical JAX installation may look something like this:
|
||||
**Summary:** For most users, a typical JAX installation may look something like this:
|
||||
|
||||
* **CPU-only (Linux/macOS/Windows)**
|
||||
```
|
||||
|
@ -30,7 +30,7 @@
|
||||
"id": "9Fg3NFNY-2RY"
|
||||
},
|
||||
"source": [
|
||||
"## TL;DR"
|
||||
"## Summary"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -32,7 +32,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co
|
||||
|
||||
+++ {"id": "9Fg3NFNY-2RY"}
|
||||
|
||||
## TL;DR
|
||||
## Summary
|
||||
|
||||
+++ {"id": "ZgMNRtXyWIW8"}
|
||||
|
||||
|
@ -27,7 +27,7 @@
|
||||
"id": "qaIsQSh1XoKF"
|
||||
},
|
||||
"source": [
|
||||
"### TL;DR\n",
|
||||
"### Summary\n",
|
||||
"\n",
|
||||
"Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n",
|
||||
"\n",
|
||||
|
@ -24,7 +24,7 @@ import jax.numpy as jnp
|
||||
|
||||
+++ {"id": "qaIsQSh1XoKF"}
|
||||
|
||||
### TL;DR
|
||||
### Summary
|
||||
|
||||
Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user