docs: runtime debugging tweaks

Mainly make titles/headings easier to read, by swapping code for words
and not using headings as links.
This commit is contained in:
Roy Frostig 2024-08-28 11:01:43 -07:00
parent 763d600508
commit 42de34263f
2 changed files with 14 additions and 9 deletions

View File

@ -10,7 +10,9 @@ Table of contents:
* [Functional error checks with jax.experimental.checkify](checkify_guide)
* [Throwing Python errors with JAXs debug flags](flags)
## [Interactive inspection with `jax.debug`](print_breakpoint)
## Interactive inspection with `jax.debug`
Complete guide [here](print_breakpoint)
**Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
@ -34,9 +36,11 @@ Table of contents:
# 🤯 0.9092974662780762 🤯
```
Click [here](print_breakpoint) to learn more!
[Read more](print_breakpoint).
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
## Functional error checks with `jax.experimental.checkify`
Complete guide [here](checkify_guide)
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
@ -77,9 +81,11 @@ Click [here](print_breakpoint) to learn more!
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
Click [here](checkify_guide) to learn more!
[Read more](checkify_guide).
## [Throwing Python errors with JAX's debug flags](flags)
## Throwing Python errors with JAX's debug flags
Complete guide [here](flags)
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
@ -92,7 +98,7 @@ def f(x, y):
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
```
Click [here](flags) to learn more!
[Read more](flags).
```{toctree}
:caption: Read more

View File

@ -1,9 +1,9 @@
# `jax.debug.print` and `jax.debug.breakpoint`
# Compiled prints and breakpoints
<!--* freshness: { reviewed: '2024-03-13' } *-->
The {mod}`jax.debug` package offers some useful tools for inspecting values
inside of JIT-ted functions.
inside of compiled functions.
## Debugging with `jax.debug.print` and other debugging callbacks
@ -27,7 +27,6 @@ f(2.)
# 🤯 0.9092974662780762 🤯
```
<!-- mattjj added this line -->
With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead!
Semantically, `jax.debug.print` is roughly equivalent to the following Python function