mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Fix typo in debugging docs
This commit is contained in:
parent
6276194e1c
commit
27b313b287
@ -5,7 +5,7 @@ inside of JIT-ted functions.
|
||||
|
||||
## Debugging with `jax.debug.print` and other debugging callbacks
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions:
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jit`-,`pmap`-, and `pjit`-decorated functions:
|
||||
|
||||
```python
|
||||
import jax
|
||||
@ -268,8 +268,8 @@ def breakpoint_if_nonfinite(x):
|
||||
pass
|
||||
def false_fn(x):
|
||||
jax.debug.breakpoint()
|
||||
lax.cond(has_nan, true_fn, false_fn, x)
|
||||
|
||||
lax.cond(is_finite, true_fn, false_fn, x)
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
z = x / y
|
||||
@ -291,5 +291,5 @@ Because `jax.debug.breakpoint` is a just an application of `jax.debug.callback`,
|
||||
* Can inspect many values at the same time, up and down the call stack
|
||||
|
||||
#### Limitations
|
||||
* Need to potentially use many breakpoints pinpoint the source of an error
|
||||
* Need to potentially use many breakpoints to pinpoint the source of an error
|
||||
* Materializes many intermediates
|
||||
|
Loading…
x
Reference in New Issue
Block a user