rocm_jax/docs/debugging.md
Roy Frostig 7b29791200 docs: shorten/clarify some page titles
Shorten the titles for derivative rules and manual parallelism, and go
for canonical wording in the parallel programming intro (typically we
"shard" data, and "partition" computation, as part of parallel
programming).
2024-08-28 11:22:06 -07:00


---
jupytext:
  formats: md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3
  language: python
  name: python3
---

(debugging)=

# Introduction to debugging

This section introduces you to a set of built-in JAX debugging methods ({func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback`) that you can use with various JAX transformations.

Let's begin with {func}`jax.debug.print`.

## `jax.debug.print` for simple inspection

Here is a rule of thumb:

- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
- Use Python {func}`print` for static values, such as dtypes and array shapes.
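As a minimal sketch of the rule of thumb above, the hypothetical `normalize` function below uses Python's `print` for the static shape and dtype, and {func}`jax.debug.print` for the traced array contents:

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Shape and dtype are static under jit, so Python's print can show them.
    # This line runs while JAX traces the function, not on every call.
    print("static shape:", x.shape, "dtype:", x.dtype)
    # The array contents are traced; jax.debug.print shows them at runtime.
    jax.debug.print("runtime values: {}", x)
    return x / jnp.sum(x)

out = normalize(jnp.array([1.0, 3.0]))
```

The `print` line appears only when `normalize` is traced (for example, on the first call with a new input shape), while the {func}`jax.debug.print` line appears on every call.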

Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`, the Python code is executed with abstract tracers in place of your arrays. Because of this, the Python {func}`print` function will only print the tracer value:

```{code-cell}
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)
```

Python's `print` executes at trace time, before the runtime values exist. If you want to print the actual runtime values, you can use {func}`jax.debug.print`:

```{code-cell}
@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)
```
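One way to see the trace-time versus runtime distinction directly is to record a plain Python side effect and call the jitted function twice with inputs of the same shape and dtype. In this sketch, `g` and `trace_events` are hypothetical names; the list grows only while JAX traces `g`, whereas the {func}`jax.debug.print` call is staged out and runs on every call:

```python
import jax
import jax.numpy as jnp

trace_events = []  # appended by plain Python code, so it only grows while tracing

@jax.jit
def g(x):
    trace_events.append(x.shape)          # trace-time side effect
    jax.debug.print("runtime x = {}", x)  # staged out; runs on every call
    return jnp.sin(x)

g(1.0)
g(2.0)  # same shape and dtype as before, so the cached trace is reused
print(len(trace_events))  # 1
```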

Similarly, within {func}`jax.vmap`, using Python's `print` will only print the tracer; to print the values being mapped over, use {func}`jax.debug.print`:

```{code-cell}
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)

result = jax.vmap(f)(xs)
```

Here's the result with {func}`jax.lax.map`, which is a sequential map rather than a vectorization:

```{code-cell}
result = jax.lax.map(f, xs)
```

Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
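The rule of thumb from the start of this section also holds under {func}`jax.vmap`: static information such as shapes can still be inspected with ordinary Python, but it describes a single example rather than the whole batch. A minimal sketch, using a hypothetical `per_example` function:

```python
import jax
import jax.numpy as jnp

seen_shapes = []  # filled by plain Python code while vmap traces the function

def per_example(x):
    # vmap traces this function with one example at a time, so the static
    # shape recorded here is the per-example shape, not the batch shape.
    seen_shapes.append(x.shape)
    return jnp.sum(x ** 2)

batch = jnp.ones((4, 3))
out = jax.vmap(per_example)(batch)
print(seen_shapes)  # [(3,)]
print(out.shape)    # (4,)
```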

Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's {func}`print`, but unlike {func}`print`, it remains consistent if you also apply {func}`jax.jit`.

```{code-cell}
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.)
```
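Because {func}`jax.debug.print` is staged out like any other operation, it keeps working when you compose {func}`jax.grad` with {func}`jax.jit`. A minimal sketch with a hypothetical `loss` function:

```python
import jax

def loss(w):
    # Prints the forward-pass (primal) value of w on every call of grad_fn.
    jax.debug.print("w = {}", w)
    return w ** 2

# Jit-compile the gradient; the debug print is still emitted at runtime.
grad_fn = jax.jit(jax.grad(loss))
dw = grad_fn(3.0)
```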

Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as printing `x` first and then `y`, add the `ordered=True` parameter.

For example:

```{code-cell}
@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2)
```
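{func}`jax.debug.callback`, covered later on this page, accepts the same `ordered` keyword. The sketch below (with a hypothetical `record` helper and `events` list) uses it to make the ordering guarantee checkable, since the host-side list keeps the order in which the callbacks ran. `functools.partial` attaches a text label because the positional arguments JAX passes through to the callback are expected to be arrays:

```python
import functools
import jax

events = []  # hypothetical host-side event log

def record(tag, value):
    events.append((tag, int(value)))

@jax.jit
def add(x, y):
    # ordered=True keeps these two callbacks in program order,
    # even though x and y do not depend on each other.
    jax.debug.callback(functools.partial(record, "x"), x, ordered=True)
    jax.debug.callback(functools.partial(record, "y"), y, ordered=True)
    return x + y

result = add(1, 2)
jax.effects_barrier()  # wait for pending callbacks before reading `events`
print(events)  # [('x', 1), ('y', 2)]
```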

To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`.

## `jax.debug.breakpoint` for `pdb`-like debugging

Summary: Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.

To pause your compiled JAX program at certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python's `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack.

To print all available commands during a breakpoint debugging session, use the `help` command. (The full set of debugger commands, the Sharp Bits, and the debugger's strengths and limitations are covered in {ref}`advanced-debugging`.)

Here is an example of what a debugger session might look like:

```{code-cell}
:tags: [skip-execution]

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution
```

[Figure: JAX debugger]

For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`:

```{code-cell}
def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.) # ==> No breakpoint
```
```{code-cell}
:tags: [skip-execution]

f(2., 0.) # ==> Pauses during execution
```
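The same value-dependent pattern works with other debugging functions. For non-interactive runs, such as a test suite where pausing is not an option, you might swap the breakpoint for a callback that records the offending value. This is a sketch: `check_finite`, `report_nonfinite`, `divide`, and `flagged` are hypothetical names, and the structure mirrors `breakpoint_if_nonfinite` above:

```python
import jax
import jax.numpy as jnp

flagged = []  # hypothetical host-side log of non-finite values

def report_nonfinite(x):
    flagged.append(float(x))

def check_finite(x):
    # Same structure as breakpoint_if_nonfinite, but records instead of pausing.
    is_finite = jnp.isfinite(x).all()
    def true_fn(x):
        pass
    def false_fn(x):
        jax.debug.callback(report_nonfinite, x)
    jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def divide(x, y):
    z = x / y
    check_finite(z)
    return z

divide(2., 1.)  # finite result: nothing is recorded
divide(2., 0.)  # 2 / 0 is inf: the callback records it
jax.effects_barrier()  # wait for pending callbacks before reading `flagged`
```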

## `jax.debug.callback` for more control during debugging

Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using the more flexible {func}`jax.debug.callback`, which gives greater control over the host-side logic executed via a Python callback. It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in {ref}`external-callbacks` for more information).

For example:

```{code-cell}
import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0);
```

This callback is compatible with other transformations, including {func}`jax.vmap` and {func}`jax.grad`:

```{code-cell}
x = jnp.arange(5.0)
jax.vmap(f)(x);
jax.grad(f)(1.0);
```

This can make {func}`jax.debug.callback` useful for general-purpose debugging.
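For instance, a callback can collect summary statistics of an intermediate value into a host-side list instead of logging them. In this sketch, `layer`, `record_stats`, and `stats` are hypothetical names; reducing on the device first means only two scalars are sent to the host:

```python
import jax
import jax.numpy as jnp

stats = []  # hypothetical host-side store for summary statistics

def record_stats(mean, max_abs):
    stats.append((float(mean), float(max_abs)))

@jax.jit
def layer(x, w):
    h = jnp.tanh(x @ w)
    # Reduce on device, then hand the two scalars to the host callback.
    jax.debug.callback(record_stats, h.mean(), jnp.abs(h).max())
    return h

x = jnp.ones((2, 3))
w = jnp.zeros((3, 4))
h = layer(x, w)
jax.effects_barrier()  # wait for pending callbacks before reading `stats`
```

Calling `jax.effects_barrier()` before reading `stats` waits for any pending callbacks, which may otherwise still be running when the jitted call returns.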

You can learn more about {func}`jax.debug.callback` and other kinds of JAX callbacks in {ref}`external-callbacks`.

## Next steps

Check out {ref}`advanced-debugging` to learn more about debugging in JAX.