From 672a013b3a90782a7ebc7c5646fb9b44d8183e3a Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 10:46:47 -0700 Subject: [PATCH] docs: prefer "summary" to "tl;dr" This was common and we typically just mean "summary." --- docs/debugging.md | 4 ++-- docs/debugging/checkify_guide.md | 2 +- docs/debugging/flags.md | 4 ++-- docs/debugging/index.md | 8 ++++---- docs/debugging/print_breakpoint.md | 5 +++-- docs/installation.md | 2 +- .../Custom_derivative_rules_for_Python_code.ipynb | 2 +- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 +- docs/notebooks/autodiff_remat.ipynb | 2 +- docs/notebooks/autodiff_remat.md | 2 +- 10 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 94384035c..1e8501f99 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -23,7 +23,7 @@ Let's begin with {func}`jax.debug.print`. ## JAX `debug.print` for high-level -**TL;DR** Here is a rule of thumb: +Here is a rule of thumb: - Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others. - Use Python {func}`print` for static values, such as dtypes and array shapes. @@ -113,7 +113,7 @@ To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`a ## JAX `debug.breakpoint` for `pdb`-like debugging -**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. +**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack. diff --git a/docs/debugging/checkify_guide.md b/docs/debugging/checkify_guide.md index 2dad9b863..8b012e97e 100644 --- a/docs/debugging/checkify_guide.md +++ b/docs/debugging/checkify_guide.md @@ -2,7 +2,7 @@ -**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: +**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 1cf1829e5..13e34a6c3 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -6,7 +6,7 @@ JAX offers flags and context managers that enable catching errors more easily. ## `jax_debug_nans` configuration option and context manager -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). `jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN. @@ -41,7 +41,7 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ## `jax_disable_jit` configuration option and context manager -**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` +**Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` `jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`). diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 724827f83..46523d681 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -2,7 +2,7 @@ -Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more. +Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has summaries and you can click the "Read more" links at the bottom to learn more. Table of contents: @@ -12,7 +12,7 @@ Table of contents: ## [Interactive inspection with `jax.debug`](print_breakpoint) - **TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, + **Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack: ```python @@ -38,7 +38,7 @@ Click [here](print_breakpoint) to learn more! ## [Functional error checks with `jax.experimental.checkify`](checkify_guide) - **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: + **Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify @@ -81,7 +81,7 @@ Click [here](checkify_guide) to learn more! ## [Throwing Python errors with JAX's debug flags](flags) -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. ```python import jax diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index d7cb68bd1..d33498697 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -7,7 +7,8 @@ inside of JIT-ted functions. ## Debugging with `jax.debug.print` and other debugging callbacks -**TL;DR** Use {func}`jax.debug.print` to print traced array values to stdout in `jit`- and `pmap`-decorated functions: +**Summary:** Use {func}`jax.debug.print` to print traced array values to +stdout in compiled (e.g. `jax.jit` or `jax.pmap`-decorated) functions: ```python import jax @@ -236,7 +237,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat ## Interactive inspection with `jax.debug.breakpoint()` -**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: +**Summary:** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: ```python @jax.jit diff --git a/docs/installation.md b/docs/installation.md index bd0473d89..4a831750e 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and cross-platform, and `jaxlib` which contains compiled binaries, and requires different builds for different operating systems and accelerators. -**TL;DR** For most users, a typical JAX installation may look something like this: +**Summary:** For most users, a typical JAX installation may look something like this: * **CPU-only (Linux/macOS/Windows)** ``` diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index ec85f6e63..6767b33a2 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -30,7 +30,7 @@ "id": "9Fg3NFNY-2RY" }, "source": [ - "## TL;DR" + "## Summary" ] }, { diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 6c948650f..000d48c49 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -32,7 +32,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co +++ {"id": "9Fg3NFNY-2RY"} -## TL;DR +## Summary +++ {"id": "ZgMNRtXyWIW8"} diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 041cf6531..82381838a 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -27,7 +27,7 @@ "id": "qaIsQSh1XoKF" }, "source": [ - "### TL;DR\n", + "### Summary\n", "\n", "Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n", "\n", diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index a4fb27c58..0a6c84b2d 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -24,7 +24,7 @@ import jax.numpy as jnp +++ {"id": "qaIsQSh1XoKF"} -### TL;DR +### Summary Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.