# Debugging runtime values

<!--* freshness: { reviewed: '2024-04-11' } *-->

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries, and you can click the "Read more" links at the bottom to learn more.

Table of contents:

* [Interactive inspection with `jax.debug`](print_breakpoint)
* [Functional error checks with `jax.experimental.checkify`](checkify_guide)
* [Throwing Python errors with JAX's debug flags](flags)
## [Interactive inspection with `jax.debug`](print_breakpoint)

**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-, `jax.pmap`-, and `pjit`-decorated functions, and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
```
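If you want to capture intermediate values programmatically rather than read them from stdout (for example, inside a test), `jax.debug.callback` calls back into Python from inside a `jit`-compiled function without stopping execution. The sketch below is one possible approach, not the only one. It assumes that collecting values into a plain Python list is acceptable for your use case. The names `recorded`, `record`, and `g` are hypothetical and are used only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side storage for values captured during execution.
recorded = []

def record(value):
  # Runs on the host and receives a concrete array, not a tracer.
  recorded.append(float(value))

@jax.jit
def g(x):
  y = jnp.sin(x)
  # Send the intermediate value y back to Python; execution continues.
  jax.debug.callback(record, y)
  return y * 2.0

out = g(2.0)
# Debug callbacks may run asynchronously, so wait for pending effects
# before reading the list.
jax.effects_barrier()
print(recorded)
```

Unlike `jax.debug.print`, the callback receives the value itself, so you can assert on it or store it for later analysis.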
Click [here](print_breakpoint) to learn more!

## [Functional error checks with `jax.experimental.checkify`](checkify_guide)

**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g., out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
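For comparison, when every check passes, the returned error carries no message and `err.get()` returns `None`. The sketch below redefines a condensed version of `f` so that it runs on its own. It then calls the checkified, jitted function once with a valid index and once with a negative index. The names `checked`, `ok_message`, and `bad_message` are illustrative.

```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return jnp.sin(x[i])

checked = jax.jit(checkify.checkify(f))

# Valid index: no check fails, so there is no error message.
err_ok, z_ok = checked(jnp.ones((5,)), 2)
ok_message = err_ok.get()

# Negative index: the user check fails and err.get() returns its message.
err_bad, _ = checked(jnp.ones((5,)), -1)
bad_message = err_bad.get()

print(ok_message)  # None
print(bad_message)
```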
You can also use checkify to automatically add common checks:

```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
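Because `err.throw()` raises, a script that needs to keep running after a failed check can catch the exception. The sketch below assumes the raised error is a `ValueError` (or a subclass of it), which is consistent with the messages shown above. It enables only `checkify.float_checks` and uses a hypothetical function `h` that produces a NaN for negative inputs.

```python
from jax.experimental import checkify
import jax.numpy as jnp

def h(x):
  # The log of a negative number is NaN, which float_checks should flag.
  return jnp.log(x)

checked_h = checkify.checkify(h, errors=checkify.float_checks)

messages = []
for value in (2.0, -1.0):
  err, out = checked_h(jnp.float32(value))
  try:
    err.throw()  # Does nothing if no check failed.
    messages.append(None)
  except ValueError as e:
    # Record the message instead of letting the exception propagate.
    messages.append(str(e))

print(messages)
```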
Click [here](checkify_guide) to learn more!

## [Throwing Python errors with JAX's debug flags](flags)

**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap`- or `pjit`-compiled code). Enable the `jax_disable_jit` flag to disable JIT compilation, which lets you use traditional Python debugging tools like `print` and `pdb`.
```python
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y

jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```
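You can also disable JIT compilation locally with the `jax.disable_jit()` context manager instead of setting `jax_disable_jit` globally. Inside the context, a `jax.jit`-decorated function runs eagerly on concrete values, so ordinary Python code such as `print`, `pdb`, or `float(...)` sees actual numbers instead of tracers. The sketch below appends to a hypothetical `seen` list to make this observable.

```python
import jax
import jax.numpy as jnp

# Hypothetical list used to observe concrete intermediate values.
seen = []

@jax.jit
def f(x):
  y = jnp.sin(x)
  # Under jit, y would be a tracer and float(y) would fail;
  # with JIT disabled, y is a concrete array.
  seen.append(float(y))
  return y

with jax.disable_jit():
  f(2.0)
  f(3.0)

print(seen)
```

Because the function body runs in ordinary Python on each call, the list grows by one entry per call. This is also what makes `pdb` breakpoints inside the function usable.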
Click [here](flags) to learn more!

```{toctree}
:caption: Read more
:maxdepth: 1

print_breakpoint
checkify_guide
flags
```