---
jupytext:
  formats: md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3
  language: python
  name: python3
---
(debugging)=
# Introduction to debugging

<!--* freshness: { reviewed: '2024-05-10' } *-->

This section introduces you to a set of built-in JAX debugging methods that you can use with various JAX transformations: {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback`.

Let's begin with {func}`jax.debug.print`.
## `jax.debug.print` for simple inspection

Here is a rule of thumb:

- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
- Use Python {func}`print` for static values, such as dtypes and array shapes.

Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`,
the Python code is executed with abstract tracers in place of your arrays. Because of this,
the Python {func}`print` function will only print this tracer value:
```{code-cell}
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)
```
Python's `print` executes at trace-time, before the runtime values exist.
If you want to print the actual runtime values, you can use {func}`jax.debug.print`:

```{code-cell}
@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)
```
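The sketch below puts the rule of thumb from the start of this section into a single function. The names `scale` and `trace_count` are illustrative, not part of JAX. Python's `print` reports the shape and dtype, which are static and known at trace time, while {func}`jax.debug.print` reports the traced sum at run time. The counter is a plain Python side effect, so it should only increase when the function is traced; the second call has the same shape and dtype and is expected to reuse the compiled version:

```python
import jax
import jax.numpy as jnp

# Illustrative counter used only to observe how often the Python body runs.
trace_count = 0

@jax.jit
def scale(x):
  global trace_count
  trace_count += 1  # Python side effect: runs only while the function is traced
  # Shape and dtype are static, so Python's print shows concrete values.
  print("static:", x.shape, x.dtype)
  total = jnp.sum(x)
  # The sum is a traced value, so jax.debug.print is needed to see it at run time.
  jax.debug.print("runtime sum: {s}", s=total)
  return x / total

a = scale(jnp.array([1.0, 2.0, 3.0]))
b = scale(jnp.array([4.0, 5.0, 6.0]))  # same shape and dtype: no retrace
```

If you call `scale` with an array of a different shape, JAX retraces the function, so both the `print` line and the counter run again.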
Similarly, within {func}`jax.vmap`, using Python's `print` will only print the tracer;
to print the values being mapped over, use {func}`jax.debug.print`:

```{code-cell}
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)

result = jax.vmap(f)(xs)
```
Here's the result with {func}`jax.lax.map`, which is a sequential map rather than a
vectorization:

```{code-cell}
result = jax.lax.map(f, xs)
```
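Despite the difference in printed output, both maps should return the same values. A minimal check of that, assuming a print-free copy of `f` (named `g` here for illustration) so that the output stays uncluttered:

```python
import jax
import jax.numpy as jnp

def g(x):
  # Same computation as `f` above, without the debug prints.
  return jnp.sin(x)

xs = jnp.arange(3.)
vectorized = jax.vmap(g)(xs)     # one batched evaluation
sequential = jax.lax.map(g, xs)  # one evaluation per element, in a loop

print(jnp.allclose(vectorized, sequential))
```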
Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.

Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's {func}`print`, but {func}`jax.debug.print` keeps behaving the same way if you also wrap the call in {func}`jax.jit`.
```{code-cell}
def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.)
```
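The consistency with {func}`jax.jit` can be checked by compiling the gradient computation. In this sketch, `f` is redefined so the block stands alone and `dfdx` is an illustrative name. {func}`jax.debug.print` is expected to print the forward-pass value of `x`, and the gradient of `x ** 2` at `1.0` is `2.0`:

```python
import jax

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

# Compile the gradient computation; the debug print still reports
# the forward-pass value of x at run time.
dfdx = jax.jit(jax.grad(f))
grad_value = dfdx(1.)
```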
Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as the `x` line first and then the `y` line second, add the `ordered=True` parameter.

For example:

```{code-cell}
@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2)
```
To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`.

## `jax.debug.breakpoint` for `pdb`-like debugging

**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.

To pause your compiled JAX program at certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack.

To print all available commands during a `breakpoint` debugging session, use the `help` command. (Full debugger commands, the Sharp Bits, and the debugger's strengths and limitations are covered in {ref}`advanced-debugging`.)

Here is an example of what a debugger session might look like:
```{code-cell}
:tags: [skip-execution]

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z

f(2.) # ==> Pauses during execution
```
For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`:

```{code-cell}
def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.) # ==> No breakpoint
```
```{code-cell}
:tags: [skip-execution]
f(2., 0.) # ==> Pauses during execution
```
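In a non-interactive setting, such as a batch job where no debugger prompt is available, the same {func}`jax.lax.cond` pattern could report non-finite values with {func}`jax.debug.print` instead of pausing. A minimal sketch; `print_if_nonfinite` and `divide` are hypothetical names, not JAX APIs:

```python
import jax
import jax.numpy as jnp

def print_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass  # all values are finite: do nothing
  def false_fn(x):
    # Executed only when some value is inf or NaN.
    jax.debug.print("non-finite value(s) detected: {}", x)
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def divide(x, y):
  z = x / y
  print_if_nonfinite(z)
  return z

ok = divide(2., 1.)   # finite result: nothing is printed
bad = divide(2., 0.)  # infinite result: the value is printed
```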
## `jax.debug.callback` for more control during debugging

Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using
the more flexible {func}`jax.debug.callback`, which gives greater control over the
host-side logic executed via a Python callback.
It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other
transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in
{ref}`external-callbacks` for more information).

For example:
```{code-cell}
import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0);
```
This callback is compatible with other transformations, including {func}`jax.vmap` and {func}`jax.grad`:
```{code-cell}
x = jnp.arange(5.0)
jax.vmap(f)(x);
```
```{code-cell}
jax.grad(f)(1.0);
```
This can make {func}`jax.debug.callback` useful for general-purpose debugging.

You can learn more about {func}`jax.debug.callback` and other kinds of JAX callbacks in {ref}`external-callbacks`.
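As one example of general-purpose use, a callback can collect intermediate values on the host so you can inspect them after the computation finishes. A minimal sketch; `recorded`, `record`, and `loss` are illustrative names. Because JAX dispatches work asynchronously, the callback may not have run yet when `loss` returns, so {func}`jax.effects_barrier` is called before reading the list:

```python
import jax
import jax.numpy as jnp

# Illustrative host-side store for intermediate values.
recorded = []

def record(value):
  # Runs on the host and receives a concrete array value.
  recorded.append(float(value))

@jax.jit
def loss(w):
  h = jnp.tanh(w) * 3.0
  jax.debug.callback(record, h)  # capture the intermediate value `h`
  return h ** 2

out = loss(0.5)
jax.effects_barrier()  # wait until pending callbacks have run
print(recorded)
```

The list lives in ordinary Python memory, so it persists across calls and keeps growing; clear it between runs if you only want the latest values.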
## Next steps
Check out {ref}`advanced-debugging` to learn more about debugging in JAX.