Fix typo in debugging docs

This commit is contained in:
Sharad Vikram 2022-08-25 15:28:01 -07:00
parent 6276194e1c
commit 27b313b287

View File

@ -5,7 +5,7 @@ inside of JIT-ted functions.
## Debugging with `jax.debug.print` and other debugging callbacks
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions:
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jit`-,`pmap`-, and `pjit`-decorated functions:
```python
import jax
@ -268,8 +268,8 @@ def breakpoint_if_nonfinite(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(has_nan, true_fn, false_fn, x)
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
@ -291,5 +291,5 @@ Because `jax.debug.breakpoint` is a just an application of `jax.debug.callback`,
* Can inspect many values at the same time, up and down the call stack
#### Limitations
* Need to potentially use many breakpoints pinpoint the source of an error
* Need to potentially use many breakpoints to pinpoint the source of an error
* Materializes many intermediates