mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add debugging tools under jax.debug
and documentation
Co-authored-by: Matthew Johnson <mattjj@google.com> Co-authored-by: Lena Martens <lenamartens@google.com>
parent
27655af6b9
commit
4386a0f909
BIN
docs/_static/debugger.gif
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 424 KiB |
158
docs/debugging/checkify_guide.md
Normal file
@ -0,0 +1,158 @@
|
||||
# The `checkify` transformation
|
||||
|
||||
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
|
||||
|
||||
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
|
||||
|
||||
You can also use checkify to automatically add common checks:
|
||||
|
||||
```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
|
||||
|
||||
## Functionalizing checks
|
||||
|
||||
The assert-like check API by itself is not functionally pure: it can raise a Python Exception as a side-effect, just like assert. So it can't be staged out with `jit`, `pmap`, `pjit`, or `scan`:
|
||||
|
||||
<!-- TODO this error message might need updating -->
|
||||
```python
jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
```
|
||||
|
||||
But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error _value_ as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be composed with staging/transforms however we like:
|
||||
|
||||
```python
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
..  at mapped index 0: index needs to be non-negative! (check failed at <...>:6 (f))
..  at mapped index 2: out-of-bounds indexing at <...>:7 (f)
"""
```
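
On a single-device machine, the same kind of composition can be tried with `jax.vmap` in place of `jax.pmap`. The sketch below is an illustration, not part of the original guide: it redefines `f` with only the user check enabled, and the input values are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return jnp.sin(x[i])

# Only user checks are enabled here (checkify's default).
checked_f = checkify.checkify(f)

xs = jnp.ones((3, 5))
idx = jnp.array([-1, 2, 4])  # the first index violates the check

# vmap-of-checkify: the error value is batched along with the output.
err, z = jax.vmap(checked_f)(xs, idx)
msg = err.get()  # None when no check failed, otherwise a message string
print(msg)
```

Because the error is an ordinary output, it is mapped just like `z`, and the message should identify which batch elements failed.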
|
||||
|
||||
## Why does JAX need checkify?
|
||||
|
||||
In some circumstances you can express runtime error checks with ordinary Python assertions, for example when only using `jax.grad` and `jax.numpy`:
|
||||
|
||||
```python
def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

jax.grad(f)(0.)
# AssertionError: must be positive!
```
|
||||
|
||||
But ordinary assertions don't work inside `jit`, `pmap`, `pjit`, or `scan`. In those cases, numeric computations are staged out (i.e. JAX will call the function with abstracted values) rather than evaluated eagerly during Python execution, and as a result numeric values aren't available:
|
||||
```python
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
```
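
To see that the failure comes from tracing rather than from the input value, the following sketch (an illustration, with an arbitrarily chosen valid input) catches the error. It assumes that the exception raised is `jax.errors.ConcretizationTypeError` or a subclass of it, which is the case in the JAX versions we are aware of.

```python
import jax
import jax.numpy as jnp

def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

try:
  # 1.0 satisfies the assertion, but under jit `x` is an abstract tracer,
  # so `x > 0.` cannot be converted to a Python bool.
  jax.jit(f)(1.)
  failed_under_jit = False
except jax.errors.ConcretizationTypeError:
  failed_under_jit = True

print(failed_under_jit)
```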
|
||||
|
||||
So we can't use ordinary Python assertions everywhere. But beyond needing a new API, the situation is trickier still:
|
||||
* XLA HLO doesn't have assertions or error semantics, so even if we had an API which worked with JAX's tracing to stage out assertions, what would we lower them to?
* JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that?
|
||||
### How can we add functional runtime checks, based on error values?
|
||||
We could manually add run-time checks to our function and plumb out values representing errors:
|
||||
|
||||
```python
def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result

err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!"
```
|
||||
|
||||
Here `f_checked` is functionally pure, since there are no exceptions involved, only ordinary values, so we know by construction that it will already work with `jit`, `pmap`, `pjit`, `scan`, and all of JAX's other transformations. The only problem is that this plumbing can be a pain!
|
||||
|
||||
checkify does this plumbing for you. When it sees a check application, checkify rewrites it to a functional form and plumbs the resulting error value through your function. That includes plumbing the error value through subsequent check calls (so that only the first is reported) and then ultimately returning it as an additional output of the checkify-transformed function:
|
||||
|
||||
```python
def f(x):
  checkify.check(x > 0., "must be positive!")  # convenient but effectful API
  return jnp.log(x)

# `checkify` is the module imported above; the transformation is `checkify.checkify`.
f_checked = checkify.checkify(f)

err, x = jax.jit(f_checked)(0.)
err.throw()
# ValueError: must be positive! (check failed at <...>:2 (f))
```
|
||||
|
||||
We call this functionalizing or discharging the effect introduced by calling `check`. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but expose `throw` and `get` methods; see {mod}`jax.experimental.checkify`.)
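
The "only the first is reported" behavior can be sketched with two checks in one function. This example is an illustration with made-up messages and thresholds (`g`, the second check, and the inputs are not from the original guide):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def g(x):
  checkify.check(x > 0., "x must be positive!")
  checkify.check(x > 10., "x must exceed 10!")
  return jnp.log(x)

checked_g = jax.jit(checkify.checkify(g))

# Both checks fail for -1.0; the error value carries the first failure.
err_bad, _ = checked_g(-1.)
msg_bad = err_bad.get()

# Neither check fails for 20.0, so there is no error to report.
err_ok, y_ok = checked_g(20.)
msg_ok = err_ok.get()

print(msg_bad)
print(msg_ok)
```

Calling `err_ok.throw()` in the passing case is a no-op, which is what makes it convenient to call `throw` unconditionally after every checked call.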
|
||||
|
||||
### Checkify can also automatically add common checks
|
||||
In addition to functionalizing check calls, checkify can add error checks automatically, e.g. checks for NaNs and out-of-bounds indexing. Consider these error cases:
|
||||
|
||||
```python
jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # nan generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero
```
|
||||
|
||||
By default checkify only discharges check applications, and won't do anything to catch errors like the above. But if you ask it to, checkify will also instrument your code with checks for these common issues automatically.
|
||||
|
||||
```python
def f(x, i):
  y = x[i]
  z = jnp.sin(y)
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
|
||||
|
||||
These automatically-added checks include source line information in their error messages.
|
||||
The API for selecting which automatic checks to enable is based on Sets. See {mod}`jax.experimental.checkify` for more details.
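
Since the check categories are sets, they can be combined with the usual set operators. The sketch below is an illustration (the function `h` and its inputs are made up) that enables only index and division checks, leaving NaN checks off:

```python
import jax.numpy as jnp
from jax.experimental import checkify

def h(x, i, d):
  return x[i] / d

# Union of two categories; NaN checks and user checks stay disabled.
errors = checkify.index_checks | checkify.div_checks
checked_h = checkify.checkify(h, errors=errors)

x = jnp.arange(3.)

err_div, _ = checked_h(x, 1, 0.)   # divides by zero
err_oob, _ = checked_h(x, 5, 2.)   # index 5 is out of bounds for size 3
err_ok, out = checked_h(x, 1, 2.)  # nothing to report

print(err_div.get())
print(err_oob.get())
print(err_ok.get())
```

Enabling fewer categories keeps the instrumented program smaller, which matters given the cost of runtime checks noted in the limitations below.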
|
||||
|
||||
## Strengths and limitations of `jax.experimental.checkify`
|
||||
<!-- TODO gotta finish this -->
|
||||
### Strengths
|
||||
* You can use it everywhere (errors are values and behave intuitively under transformations)
|
||||
* Automatic instrumentation: you don't need to make local modifications to your code. Instead, `checkify` can instrument all of it!
|
||||
### Limitations
|
||||
* Adding a lot of runtime checks can be expensive
|
||||
* Requires threading error values in and out of transformations
|
76
docs/debugging/flags.md
Normal file
@ -0,0 +1,76 @@
|
||||
# JAX debugging flags
|
||||
|
||||
JAX offers flags and context managers that make errors easier to catch.
|
||||
|
||||
## `jax_debug_nans` configuration option and context manager
|
||||
|
||||
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).
|
||||
|
||||
`jax_debug_nans` is a JAX flag that, when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled code: when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN.
|
||||
|
||||
### Usage
|
||||
|
||||
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
|
||||
* setting the `JAX_DEBUG_NANS=True` environment variable;
|
||||
* adding `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
|
||||
* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then setting the option using a command-line flag like `--jax_debug_nans=True`.
|
||||
|
||||
### Example(s)
|
||||
|
||||
```python
import jax
from jax.config import config
config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```
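
The heading above also mentions a context manager. A minimal sketch of that form, assuming `jax.debug_nans` is available as a context manager in your JAX version, limits NaN checking to one block instead of the whole program:

```python
import jax

def f(x, y):
  return x / y

caught = False
try:
  # NaN checking is enabled only inside this `with` block.
  with jax.debug_nans(True):
    jax.jit(f)(0., 0.)  # 0/0 produces a NaN
except FloatingPointError as e:
  caught = True
  print("NaN detected:", e)
```

Scoping the check this way avoids the eager re-run cost for parts of the program that are not under investigation.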
|
||||
|
||||
#### Strengths and limitations of `jax_debug_nans`
|
||||
##### Strengths
|
||||
* Easy to apply
|
||||
* Precisely detects where NaNs were produced
|
||||
* Throws a standard Python exception and is compatible with PDB postmortem
|
||||
|
||||
##### Limitations
|
||||
* Not compatible with `jax.pmap` or `jax.pjit`
|
||||
* Re-running functions eagerly can be slow
|
||||
* Errors on false positives (e.g. intentionally created NaNs)
|
||||
|
||||
## `jax_disable_jit` configuration option and context manager
|
||||
|
||||
**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
|
||||
`jax_disable_jit` is a JAX flag that, when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`).
|
||||
|
||||
### Usage
|
||||
|
||||
You can disable JIT-compilation by:
|
||||
* setting the `JAX_DISABLE_JIT=True` environment variable;
|
||||
* adding `from jax.config import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
|
||||
* adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then setting the option using a command-line flag like `--jax_disable_jit=True`.
|
||||
|
||||
### Examples
|
||||
|
||||
```python
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_disable_jit", True)

def f(x):
  y = jnp.log(x)
  if jnp.isnan(y):
    breakpoint()
  return y
jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!
```
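
As with `jax_debug_nans`, there is a context-manager form. The sketch below assumes `jax.disable_jit()` is available in your JAX version; the function `relu_py` is a made-up example whose Python `if` on a value would fail under tracing but runs normally once JIT is disabled:

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_py(x):
  # Python control flow on a value: this needs a concrete `x`.
  if x > 0:
    return x
  return jnp.zeros_like(x)

# JIT is disabled only inside this block, so `x` is concrete here.
with jax.disable_jit():
  out = relu_py(jnp.float32(-2.))

print(out)
```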
|
||||
|
||||
#### Strengths and limitations of `jax_disable_jit`
|
||||
|
||||
##### Strengths
|
||||
* Easy to apply
|
||||
* Enables use of Python's built-in `breakpoint` and `print`
|
||||
* Throws standard Python exceptions and is compatible with PDB postmortem
|
||||
|
||||
##### Limitations
|
||||
* Not compatible with `jax.pmap` or `jax.pjit`
|
||||
* Running functions without JIT-compilation can be slow
|
91
docs/debugging/index.md
Normal file
@ -0,0 +1,91 @@
|
||||
# Debugging in JAX
|
||||
|
||||
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools!
|
||||
|
||||
## [Interactive inspection with `jax.debug`](print_breakpoint)
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions
|
||||
and use {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
|
||||
|
||||
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
```
|
||||
|
||||
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
|
||||
|
||||
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
|
||||
|
||||
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
|
||||
|
||||
You can also use checkify to automatically add common checks:
|
||||
|
||||
```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <...>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
|
||||
|
||||
## [Throwing Python errors with JAX's debug flags](flags)
|
||||
|
||||
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
|
||||
|
||||
```python
import jax
from jax.config import config
config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```
|
||||
|
||||
```{toctree}
:caption: Index
:maxdepth: 1

print_breakpoint
checkify_guide
flags
```
|
||||
|
293
docs/debugging/print_breakpoint.md
Normal file
@ -0,0 +1,293 @@
|
||||
# `jax.debug.print` and `jax.debug.breakpoint`
|
||||
|
||||
The {mod}`jax.debug` package offers some useful tools for inspecting values
|
||||
inside of JIT-ted functions.
|
||||
|
||||
## Debugging with `jax.debug.print` and other debugging callbacks
|
||||
|
||||
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions:
|
||||
|
||||
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
```
|
||||
|
||||
<!-- mattjj added this line -->
|
||||
With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead!
|
||||
|
||||
Semantically, `jax.debug.print` is roughly equivalent to the following Python function
|
||||
|
||||
```python
# Pseudocode signature: `PyTree[Array]` stands for any pytree of arrays.
def debug_print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs))
```
|
||||
except that it can be staged out and transformed by JAX. See the {func}`API reference <jax.debug.print>` for more details.
|
||||
|
||||
### Why "_debug_" print?
|
||||
In the name of debugging, `jax.debug.print` can reveal information about _how_ computations are evaluated:
|
||||
|
||||
```python
xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)  # use jnp (not np) so the function can be traced
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974
```
|
||||
Notice that the printed results are in different orders!
|
||||
|
||||
By revealing these inner workings, the output of `jax.debug.print` doesn't respect JAX's usual semantic guarantees, such as `jax.vmap(f)(xs)` and `jax.lax.map(f, xs)` computing the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!
|
||||
|
||||
<!-- mattjj tweaked this line -->
|
||||
So use `jax.debug.print` for debugging, and not when semantics guarantees are important.
|
||||
|
||||
### More examples of `jax.debug.print`
|
||||
|
||||
In addition to the above examples using `jit` and `vmap`, here are a few more to have in mind.
|
||||
|
||||
#### Printing under `jax.pmap`
|
||||
|
||||
When `jax.pmap`-ed, `jax.debug.print`s might be reordered!
|
||||
```python
xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
# OR
# Prints: x: 1.0
#         x: 0.0
```
|
||||
|
||||
#### Printing under `jax.grad`
|
||||
|
||||
Under a `jax.grad`, `jax.debug.print`s will only print on the forward pass:
|
||||
```python
def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0
```
|
||||
|
||||
<!-- mattjj added this line -->
|
||||
This behavior is similar to how Python's builtin `print` works under a `jax.grad`. But by using `jax.debug.print` here, the behavior is the same even if the caller applies a `jax.jit`.
|
||||
|
||||
To print on the backward pass, just use a `jax.custom_vjp`:
|
||||
|
||||
```python
@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  return x, None

def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)


def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
```
|
||||
|
||||
### More control with `jax.debug.callback`
|
||||
|
||||
In fact, `jax.debug.print` is a thin convenience wrapper around `jax.debug.callback`, which can be used directly for greater control over string formatting, or even the kind of output.
|
||||
|
||||
Semantically, `jax.debug.callback` is roughly equivalent to the following Python function
|
||||
|
||||
```python
# Pseudocode signature: `PyTree[Array]` stands for any pytree of arrays.
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None
```
|
||||
|
||||
As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use a callback for anything else its behavior might surprise you under transformations. For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might be reordered and run asynchronously (see below).
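
As a sketch of using `jax.debug.callback` directly, the example below records an intermediate value into a Python list instead of printing it. The names `log` and `record` are made up for illustration, and the call to `jax.effects_barrier()` (described under "Asynchronous callbacks" in this guide) is there because the callback may run asynchronously.

```python
import jax
import jax.numpy as jnp

log = []

def record(value):
  # Runs on the host with a concrete array, not a tracer.
  log.append(float(value))

@jax.jit
def f(x):
  y = jnp.sin(x)
  jax.debug.callback(record, y)
  return y

out = f(0.)
jax.effects_barrier()  # wait for any pending callbacks to finish
print(log)
```

The callback's return value is discarded, so `record` communicates only through its side effect on `log`.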
|
||||
|
||||
### Sharp bits
|
||||
Like most JAX APIs, `jax.debug.print` can cut you if you're not careful.
|
||||
|
||||
#### Ordering of printed results
|
||||
When distinct calls to `jax.debug.print` involve arguments which don't depend on one another, they might be reordered when staged out, e.g. by `jax.jit`:
|
||||
|
||||
```python
@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0
```
|
||||
|
||||
Why? Under the hood, the compiler gets a functional representation of the staged-out computation, where the imperative order of the Python function is lost and only data dependence remains. This change is invisible to users with functionally pure code, but in the presence of side-effects like printing, it's noticeable.
|
||||
|
||||
To preserve the original order of `jax.debug.print`s as written in your Python function, you can use `jax.debug.print(..., ordered=True)`, which will ensure the relative order of prints is preserved. But using `ordered=True` will raise an error under `jax.pmap` and other JAX transformations involving parallelism, since ordering can't be guaranteed under parallel execution.
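
A minimal sketch of `ordered=True`, using `jax.debug.callback` (which accepts the same keyword) so that the order can be inspected from Python. The list `order` and the helper lambdas are illustrative names only; the example assumes a single-device, non-parallel setting.

```python
import jax

order = []

@jax.jit
def f(x, y):
  # With ordered=True, these two callbacks keep their program order
  # even though `x` and `y` do not depend on one another.
  jax.debug.callback(lambda v: order.append(("x", float(v))), x, ordered=True)
  jax.debug.callback(lambda v: order.append(("y", float(v))), y, ordered=True)
  return x + y

out = f(2., 3.)
jax.effects_barrier()  # make sure both callbacks have run
print(order)
```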
|
||||
|
||||
#### Asynchronous callbacks
|
||||
|
||||
Depending on the backend, `jax.debug.print`s may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value.
|
||||
```python
@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.
```
|
||||
|
||||
To block on the `jax.debug.print`s in a function, you can call `jax.effects_barrier()`, which will wait until any remaining side-effects in the function have completed as well:
|
||||
|
||||
```python
@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>
```
|
||||
#### Performance impacts
|
||||
|
||||
##### Unnecessary materialization
|
||||
|
||||
While `jax.debug.print` was designed to have a minimal performance footprint, it can interfere with compiler optimizations and potentially affect the memory profile of your JAX programs.
|
||||
```python
def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)
```
|
||||
In this example, we are printing intermediate values in between a linear layer and the activation function. Compilers like XLA can perform fusion optimizations, which might avoid materializing `logits` in memory. But when we use `jax.debug.print` on `logits`, we are forcing those intermediates to be materialized, potentially slowing down the program and increasing memory usage.
|
||||
|
||||
Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronization occurs that will materialize values on a single device.
|
||||
|
||||
##### Callback overhead
|
||||
|
||||
`jax.debug.print` inherently incurs communication between an accelerator and its host. The underlying mechanism differs from backend to backend (e.g. GPU vs TPU) but in all cases, we'll need to copy the printed values from device to host. In the CPU case, this overhead is smaller.
|
||||
|
||||
Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronization occurs that adds some overhead.
|
||||
|
||||
### Strengths and limitations of `jax.debug.print`
|
||||
#### Strengths
|
||||
* Print debugging is simple and intuitive
|
||||
* `jax.debug.callback` can be used for other innocuous side-effects
|
||||
|
||||
#### Limitations
|
||||
* Adding print statements is a manual process
|
||||
* Can have performance impacts
|
||||
|
||||
## Interactive inspection with `jax.debug.breakpoint()`
|
||||
|
||||
**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
|
||||
|
||||
```python
@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.)  # ==> Pauses during execution!
```
|
||||
|
||||
|
||||
`jax.debug.breakpoint()` is actually just an application of `jax.debug.callback(...)` that captures information about the call stack. It has the same transformation behaviors as `jax.debug.print` as a result (e.g. `vmap`-ing `jax.debug.breakpoint()` unrolls it across the mapped axis).
|
||||
|
||||
### Usage
|
||||
|
||||
Calling `jax.debug.breakpoint()` in a compiled JAX function will pause your program when it hits the breakpoint. You'll be presented with a `pdb`-like prompt that allows you to inspect the values in the call stack. Unlike `pdb`, you will not be able to step through the execution, but you are allowed to resume it.
|
||||
|
||||
Debugger commands:
|
||||
* `help` - prints out available commands
|
||||
* `p` - evaluates an expression and prints its result
|
||||
* `pp` - evaluates an expression and pretty-prints its result
|
||||
* `u(p)` - go up a stack frame
|
||||
* `d(own)` - go down a stack frame
|
||||
* `w(here)/bt` - print out a backtrace
|
||||
* `l(ist)` - print out code context
|
||||
* `c(ont(inue))` - resumes the execution of the program
|
||||
* `q(uit)/exit` - exits the program (does not work on TPU)
|
||||
|
||||
### Examples
|
||||
|
||||
#### Usage with `jax.lax.cond`
|
||||
|
||||
When combined with `jax.lax.cond`, the debugger can become a useful tool for detecting `nan`s or `inf`s.
|
||||
|
||||
```python
def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  # Enter the debugger only when some value is NaN or inf.
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.)  # ==> Pauses during execution!
```
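
The same conditional pattern can be tried without an interactive prompt by swapping the breakpoint for a `jax.debug.callback`. This is a sketch of a substituted technique, not the debugger itself; `flagged` and `report_if_nonfinite` are made-up names.

```python
import jax
import jax.numpy as jnp

flagged = []

def report_if_nonfinite(x):
  def ok(x):
    return None
  def bad(x):
    # Record the offending value on the host instead of pausing.
    jax.debug.callback(lambda v: flagged.append(float(v)), x)
    return None
  jax.lax.cond(jnp.isfinite(x).all(), ok, bad, x)

@jax.jit
def f(x, y):
  z = x / y
  report_if_nonfinite(z)
  return z

f(2., 0.)  # 2/0 is inf, so the `bad` branch runs
f(2., 1.)  # finite result, nothing is recorded
jax.effects_barrier()
print(flagged)
```

This variant is handy in non-interactive runs (for example, batch jobs), where pausing at a prompt is not an option.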
|
||||
|
||||
### Sharp bits
|
||||
Because `jax.debug.breakpoint` is just an application of `jax.debug.callback`, it has the same [sharp bits as `jax.debug.print`](#sharp-bits), with a few more caveats:
|
||||
* `jax.debug.breakpoint` materializes *even more* intermediates than `jax.debug.print` because it forces materialization of all values in the call stack
|
||||
* `jax.debug.breakpoint` has more runtime overhead than a `jax.debug.print` because it has to potentially copy all the intermediate values in a JAX program from device to host.
|
||||
|
||||
|
||||
### Strengths and limitations of `jax.debug.breakpoint()`
|
||||
|
||||
#### Strengths
|
||||
* Simple, intuitive and (somewhat) standard
|
||||
* Can inspect many values at the same time, up and down the call stack
|
||||
|
||||
#### Limitations
|
||||
* Potentially need to use many breakpoints to pinpoint the source of an error
|
||||
* Materializes many intermediates
|
@ -19,6 +19,11 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
|
||||
|
||||
jax-101/index
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
debugging/index
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Reference Documentation
|
||||
|
19
docs/jax.debug.rst
Normal file
@ -0,0 +1,19 @@
|
||||
|
||||
jax.debug package
|
||||
=================
|
||||
|
||||
.. currentmodule:: jax.debug
|
||||
|
||||
.. automodule:: jax.debug
|
||||
|
||||
Debugging utilities
|
||||
--------------------------
|
||||
|
||||
:doc:`debugging/print_breakpoint` describes how to make use of JAX's debugging features.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
callback
|
||||
print
|
||||
breakpoint
|
24
docs/jax.experimental.checkify.rst
Normal file
@ -0,0 +1,24 @@
|
||||
jax.experimental.checkify module
|
||||
=====================================
|
||||
|
||||
|
||||
.. automodule:: jax.experimental.checkify
|
||||
|
||||
API
|
||||
---
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
checkify
|
||||
check
|
||||
check_error
|
||||
Error
|
||||
ErrorCategory
|
||||
user_checks
|
||||
nan_checks
|
||||
index_checks
|
||||
div_checks
|
||||
float_checks
|
||||
automatic_checks
|
||||
all_checks
|
@ -14,6 +14,7 @@ Experimental Modules
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
jax.experimental.checkify
|
||||
jax.experimental.global_device_array
|
||||
jax.experimental.host_callback
|
||||
jax.experimental.maps
|
||||
|
@ -12,6 +12,7 @@ Subpackages
|
||||
jax.numpy
|
||||
jax.scipy
|
||||
jax.config
|
||||
jax.debug
|
||||
jax.dlpack
|
||||
jax.distributed
|
||||
jax.example_libraries
|
||||
|
@ -130,6 +130,7 @@ from jax._src.tree_util import (
|
||||
from jax import abstract_arrays as abstract_arrays
|
||||
from jax import api_util as api_util
|
||||
from jax import distributed as distributed
|
||||
from jax import debug as debug
|
||||
from jax import dtypes as dtypes
|
||||
from jax import errors as errors
|
||||
from jax import image as image
|
||||
|
@ -137,8 +137,4 @@ def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined
|
||||
with debug_lock:
|
||||
debugger(frames, thread_id, **kwargs)
|
||||
|
||||
if ordered:
|
||||
effect = debugging.DebugEffect.ORDERED_PRINT
|
||||
else:
|
||||
effect = debugging.DebugEffect.PRINT
|
||||
debugging.debug_callback(_breakpoint_callback, effect, *flat_args)
|
||||
debugging.debug_callback(_breakpoint_callback, *flat_args, ordered=ordered)
|
||||
|
@ -110,8 +110,8 @@ if jaxlib.version >= (0, 3, 15):
|
||||
mlir.register_lowering(
|
||||
debug_callback_p, debug_callback_lowering, platform="tpu")
|
||||
|
||||
def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args,
|
||||
**kwargs):
|
||||
def debug_callback(callback: Callable[..., Any], *args: Any,
|
||||
ordered: bool = False, **kwargs: Any):
|
||||
"""Calls a stageable Python callback.
|
||||
|
||||
`debug_callback` enables you to pass in a Python function that can be called
|
||||
@ -128,33 +128,33 @@ def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args,
|
||||
|
||||
Args:
|
||||
callback: A Python callable.
|
||||
effect: A `DebugEffect`.
|
||||
*args: The positional arguments to the callback.
|
||||
ordered: A keyword only argument used to indicate whether or not the
|
||||
staged out computation will enforce ordering of this callback w.r.t.
|
||||
other ordered callbacks.
|
||||
**kwargs: The keyword arguments to the callback.
|
||||
Returns:
|
||||
The value of `callback(*args, **kwargs)`.
|
||||
"""
|
||||
if not isinstance(effect, DebugEffect):
|
||||
raise ValueError("Can only use `DebugEffect` effects in `debug_callback`")
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT
|
||||
return debug_callback_p.bind(*flat_args, callback=callback, effect=effect,
|
||||
in_tree=in_tree)
|
||||
|
||||
def _format_print_callback(fmt: str, *args, **kwargs):
|
||||
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
|
||||
|
||||
def debug_print(fmt: str, *args, ordered=False, **kwargs) -> None:
|
||||
def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
|
||||
"""Prints values and works in staged out JAX functions.
|
||||
|
||||
Args:
|
||||
fmt: A format string, e.g. `"hello {x}"`, that will be used to format
|
||||
fmt: A format string, e.g. ``"hello {x}"``, that will be used to format
|
||||
input arguments.
|
||||
*args: A list of positional arguments to be formatted.
|
||||
ordered: A keyword only argument used to indicate whether or not the
|
||||
staged out computation will enforce ordering of this `debug_print` w.r.t.
|
||||
other ordered `debug_print`s.
|
||||
staged out computation will enforce ordering of this ``debug_print``
|
||||
w.r.t. other ordered ``debug_print`` calls.
|
||||
**kwargs: Additional keyword arguments to be formatted.
|
||||
"""
|
||||
effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT
|
||||
debug_callback(functools.partial(_format_print_callback, fmt), effect, *args,
|
||||
**kwargs)
|
||||
debug_callback(functools.partial(_format_print_callback, fmt), *args,
|
||||
**kwargs, ordered=ordered)
|
||||
|
17
jax/debug.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax._src.debugging import debug_callback as callback
|
||||
from jax._src.debugging import debug_print as print
|
||||
from jax._src.debugging import DebugEffect
|
||||
from jax._src.debugger import breakpoint
|
@ -153,6 +153,9 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
# TODO: Remove once tensorflow is 2.10.0 everywhere.
|
||||
if p.name == "optimization_barrier":
|
||||
continue
|
||||
if p.name == "debug_callback":
|
||||
# TODO(sharadmv,necula): enable debug callbacks in TF
|
||||
continue
|
||||
if p.name in tf_not_yet_impl:
|
||||
self.assertNotIn(
|
||||
p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl
|
||||
|