Roy Frostig 3c223cd253 docs: tidy up titles and headings
This shortens some titles and makes them more consistent. It also
removes "JAX" from several titles ("in JAX", "for JAX", "JAX's",
etc.). Since these are JAX docs, that ought to be clear from context.
2024-08-13 11:53:57 -07:00


# Debugging runtime values
<!--* freshness: { reviewed: '2024-04-11' } *-->
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries, and you can follow the "Read more" links at the bottom to learn more.
Table of contents:
* [Interactive inspection with `jax.debug`](print_breakpoint)
* [Functional error checks with `jax.experimental.checkify`](checkify_guide)
* [Throwing Python errors with JAX's debug flags](flags)
## [Interactive inspection with `jax.debug`](print_breakpoint)
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-, `jax.pmap`-, and `pjit`-decorated functions,
and {func}`jax.debug.breakpoint` to pause execution of your compiled function and inspect values in the call stack:
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
```
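To see why `jax.debug.print` is needed at all, here is a minimal sketch contrasting it with Python's built-in `print` inside a jitted function (the function name `g` is hypothetical). Python's `print` runs only while JAX traces the function, so it sees an abstract tracer and is skipped entirely on cached calls; `jax.debug.print` is staged into the compiled computation and reports the runtime value on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):  # hypothetical example function
  # Python's print runs only while JAX traces g, so it shows an abstract tracer.
  print("print:", x)
  # jax.debug.print is staged into the compiled computation, so it shows
  # the runtime value on every call.
  jax.debug.print("debug.print: {x}", x=x)
  return x * 2

g(jnp.float32(3.0)).block_until_ready()  # first call traces g: both lines print
g(jnp.float32(4.0)).block_until_ready()  # cached call: only jax.debug.print prints
jax.effects_barrier()  # wait for any pending debug callbacks to finish
```

Because the second call reuses the compiled computation, Python's `print` does not run again, which is the main reason to reach for `jax.debug.print` when inspecting values inside transformed code.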
Click [here](print_breakpoint) to learn more!
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g., out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add your own runtime checks:
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
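When every check passes, the returned error is empty: `err.get()` returns `None` and `err.throw()` does nothing. A minimal sketch with a valid index, restating `f` from above so the block runs on its own:

```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

# A valid, non-negative index: the user check passes.
err, z = jax.jit(jittable_f)(jnp.ones((5,)), 2)
print(err.get())  # None, since no check failed
err.throw()       # does nothing when no check failed
print(z)          # the value of sin(1.0)
```

Since the error comes back as an ordinary value, you decide where to inspect or raise it, outside the compiled code.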
You can also use checkify to automatically add common checks:
```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
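Because `err.throw()` raises on a failed check, the calls above are best read one at a time. To survey several inputs in a single run, you can inspect `err.get()` instead of throwing. A minimal sketch, assuming the same `f` and error set as above; the `cases` list is hypothetical, and the exact message text may vary across JAX versions. It also wraps the checkified function in `jax.jit`:

```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
# The checkified function can itself be jitted; errors come back as values.
checked_f = jax.jit(checkify.checkify(f, errors=errors))

# Hypothetical (x, i) inputs chosen to trip each kind of check.
cases = [
  (jnp.ones((5,)), 100),           # out-of-bounds index
  (jnp.ones((5,)), -1),            # fails the user-defined check
  (jnp.array([jnp.inf, 1.0]), 0),  # sin(inf) produces a NaN
  (jnp.ones((5,)), 2),             # valid input, no check fails
]

messages = []
for x, i in cases:
  err, z = checked_f(x, i)
  messages.append(err.get())  # None when no check failed

for (_, i), msg in zip(cases, messages):
  print(f"i={i}: {'ok' if msg is None else msg}")
```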
Click [here](checkify_guide) to learn more!
## [Throwing Python errors with JAX's debug flags](flags)
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap`- or `jax.pjit`-compiled code). Enable the `jax_disable_jit` flag to disable JIT compilation, so that you can use traditional Python debugging tools like `print` and `pdb`.
```python
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y

jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```
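The `jax_disable_jit` flag mentioned above can also be applied locally with the `jax.disable_jit()` context manager, as an alternative to setting it globally with `jax.config.update("jax_disable_jit", True)`. Inside the context, a `jax.jit`-decorated function runs op by op, so a plain Python `print` (or a `pdb` breakpoint) sees concrete values rather than tracers. A minimal sketch; the function name `h` is hypothetical:

```python
import jax
import jax.numpy as jnp

@jax.jit
def h(x):  # hypothetical example function
  y = jnp.sin(x)
  # With JIT disabled, y is a concrete array, so ordinary Python
  # debugging (print, pdb.set_trace(), ...) works here.
  print("y =", y)
  return y * 2

with jax.disable_jit():
  h(jnp.float32(2.0))  # prints the concrete value of sin(2.0)
```

Disabling JIT makes execution much slower, so it is best reserved for narrowing down a problem rather than left on.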
Click [here](flags) to learn more!
```{toctree}
:caption: Read more
:maxdepth: 1
print_breakpoint
checkify_guide
flags
```