Roy Frostig 3c223cd253 docs: tidy up titles and headings
This shortens some titles and makes them more consistent. It also
removes "JAX" from several titles ("in JAX", "for JAX", "JAX's",
etc.). Since these are JAX docs, that ought to be clear from context.
2024-08-13 11:53:57 -07:00


Debugging runtime values

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page gives a TL;DR summary of each tool. Follow the "Read more" links at the bottom for the full guides.


Interactive inspection with jax.debug

TL;DR Use jax.debug.print to print values to stdout in jax.jit-, jax.pmap-, and pjit-decorated functions. Use jax.debug.breakpoint to pause execution of your compiled function and inspect values in the call stack:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
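A plain Python print inside a jitted function runs only while JAX traces the function, so it shows an abstract tracer rather than the runtime value. The sketch below contrasts this with jax.debug.callback, the more general host-callback utility in the same jax.debug module (as we understand it, jax.debug.print is built on the same mechanism). Collecting the values into a list lets you inspect them after the call. The names `record` and `seen` are illustrative only:

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # runtime values collected on the host


def record(value):
    # Runs on the host with a concrete array, never a tracer.
    seen.append(np.asarray(value))


@jax.jit
def f(x):
    # Ordinary Python print runs only while JAX traces f, so it shows an
    # abstract tracer rather than the runtime value.
    print("trace time:", x)
    y = jnp.sin(x)
    # jax.debug.callback is staged into the compiled function and runs on
    # every call with the actual value of y; ordered=True keeps call order.
    jax.debug.callback(record, y, ordered=True)
    return y


f(2.0)
f(3.0)  # same shape and dtype, so no retrace: the trace-time print stays silent
jax.effects_barrier()  # block until pending callbacks have run
```

After both calls, `seen` holds sin(2.0) and sin(3.0), while the trace-time print fired only once.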

Click here to learn more!

Functional error checks with jax.experimental.checkify

TL;DR Checkify lets you add jit-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the checkify.checkify transformation together with the assert-like checkify.check function to add runtime checks to JAX code:

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
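A checkified function does not raise on its own. Instead, it returns an error value alongside the normal output, and `err.get()` returns `None` when every check passed. The sketch below shows both outcomes for the same jitted function. The function name `safe_log` and its message are hypothetical examples, not part of the checkify API:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # User check: every element must be positive before taking the log.
    checkify.check(jnp.all(x > 0), "log input must be positive!")
    return jnp.log(x)


# checkify first, then jit: the checked function returns (error, output).
checked_log = jax.jit(checkify.checkify(safe_log))

err_ok, y_ok = checked_log(jnp.array([1.0, 2.0]))
err_bad, y_bad = checked_log(jnp.array([1.0, -2.0]))

ok_msg = err_ok.get()    # None, because no check failed
bad_msg = err_bad.get()  # a string describing the failed check
```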

You can also use checkify to automatically add common checks:

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
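The `errors` argument decides which checks checkify adds. By default only user checks (checkify.check calls) are enabled, so an out-of-bounds read goes unreported. For retrieval, JAX clamps out-of-bounds indices instead of raising. The sketch below runs the same lookup under both settings. The helper `get_item` is hypothetical:

```python
import jax.numpy as jnp
from jax.experimental import checkify


def get_item(x, i):
    return x[i]


x = jnp.ones((5,))

# Default errors set is user_checks only, so this out-of-bounds read is not
# reported (the index is clamped and a value is still returned).
err_default, _ = checkify.checkify(get_item)(x, 100)

# Opting in to index_checks makes the same call report an error.
err_index, _ = checkify.checkify(get_item, errors=checkify.index_checks)(x, 100)

default_msg = err_default.get()
index_msg = err_index.get()
```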

Click here to learn more!

Throwing Python errors with JAX's debug flags

TL;DR Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code (but not in jax.pmap- or jax.pjit-compiled code). Enable the jax_disable_jit flag to turn off JIT compilation, which lets you use traditional Python debugging tools like print and pdb.

import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
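Both flags affect JAX globally, so it can help to scope them. The sketch below makes two assumptions. First, it assumes you want to catch the FloatingPointError and switch jax_debug_nans back off afterwards. Second, it uses the jax.disable_jit() context manager to turn off compilation only for a block of code. The function `g` and the list `seen` are illustrative:

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x / y


# jax_debug_nans is a global setting: catch the error, then switch it back off.
jax.config.update("jax_debug_nans", True)
caught = None
try:
    jax.jit(f)(0.0, 0.0)
except FloatingPointError as e:
    caught = e
finally:
    jax.config.update("jax_debug_nans", False)

seen = []


@jax.jit
def g(x):
    # With JIT disabled, this Python code runs eagerly on every call,
    # so it records concrete values instead of a single tracer.
    seen.append(x)
    return jnp.sin(x)


# jax.disable_jit() turns off compilation only inside this block.
with jax.disable_jit():
    g(2.0)
    g(3.0)
```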

Click here to learn more!

Read more:

print_breakpoint
checkify_guide
flags