# Debugging runtime values

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries, and you can click the "Read more" links at the bottom to learn more.

Table of contents:

* [Interactive inspection with `jax.debug`](print_breakpoint)
* [Functional error checks with `jax.experimental.checkify`](checkify_guide)
* [Throwing Python errors with JAX's debug flags](flags)

## [Interactive inspection with `jax.debug`](print_breakpoint)

**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-, `jax.pmap`-, and `pjit`-decorated functions, and use {func}`jax.debug.breakpoint` to pause execution of your compiled function and inspect values in the call stack:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
```

Click [here](print_breakpoint) to learn more!

## [Functional error checks with `jax.experimental.checkify`](checkify_guide)

**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:

```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```

You can also use checkify to automatically add common checks:

```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```

Click [here](checkify_guide) to learn more!

## [Throwing Python errors with JAX's debug flags](flags)

**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap`- or `jax.pjit`-compiled code), and enable the `jax_disable_jit` flag to disable JIT compilation so that you can use traditional Python debugging tools like `print` and `pdb`.

```python
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y

jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```

Click [here](flags) to learn more!

```{toctree}
:caption: Read more
:maxdepth: 1

print_breakpoint
checkify_guide
flags
```
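The debug-flags summary mentions `jax_disable_jit` but only shows `jax_debug_nans` in code. Below is a minimal sketch of two ways to turn JIT off: the `jax_disable_jit` config flag and the `jax.disable_jit()` context manager. The function `f` and the `seen` list are hypothetical, for illustration only. With JIT disabled, the `@jax.jit`-decorated function runs as ordinary Python on concrete arrays, so plain `print` and `float(...)` work inside it.

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical Python-side log of values observed inside `f`

@jax.jit
def f(x):
  y = jnp.sin(x)
  # With JIT disabled, `y` is a concrete array rather than a tracer, so
  # plain Python tools work here: `print`, `float(...)`, or `breakpoint()`.
  print("y =", y)
  seen.append(float(y))  # would raise a concretization error under jit
  return 2 * y

# Option 1: flip the global flag, mirroring the jax_debug_nans example.
jax.config.update("jax_disable_jit", True)
try:
  out1 = f(2.0)
finally:
  jax.config.update("jax_disable_jit", False)  # restore normal compilation

# Option 2: scope the change to a block with the context manager.
with jax.disable_jit():
  out2 = f(2.0)
```

The context manager is usually the safer choice when debugging a single call, because it restores normal compilation automatically when the block exits.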