Add debugging tools under jax.debug and documentation

Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
This commit is contained in:
Sharad Vikram 2022-07-26 14:47:36 -07:00
parent 27655af6b9
commit 4386a0f909
15 changed files with 702 additions and 17 deletions

BIN
docs/_static/debugger.gif vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 424 KiB

View File

@ -0,0 +1,158 @@
# The `checkify` transformation
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
You can also use checkify to automatically add common checks:
```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
## Functionalizing checks
The assert-like check API by itself is not functionally pure: it can raise a Python Exception as a side-effect, just like assert. So it can't be staged out with `jit`, `pmap`, `pjit`, or `scan`:
<!-- TODO this error message might need updating -->
```python
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
```
But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error _value_ as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be composed with staging/transforms however we like:
```python
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
```
## Why does JAX need checkify?
In some circumstances you can express runtime error checks with ordinary Python assertions, for example when only using `jax.grad` and `jax.numpy`:
```python
def f(x):
assert x > 0., "must be positive!"
return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!"
```
But ordinary assertions don't work inside `jit`, `pmap`, `pjit`, or `scan`. In those cases, numeric computations are staged out (i.e. JAX will call the function with abstracted values) rather than evaluated eagerly during Python execution, and as a result numeric values aren't available:
```python
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
```
So we can't use ordinary Python assertions everywhere. But beyond needing a new API, the situation is trickier still:
XLA HLO doesn't have assertions or error semantics, so even if we had an API which worked with JAX's tracing to stage out assertions, what would we lower them to?
JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that?
How can we add functional runtime checks, based on error values?
We could manually add run-time checks to our function and plumb out values representing errors:
```python
def f_checked(x):
error = x <= 0.
result = jnp.log(x)
return error, result
err, y = jax.jit(f_checked)(0.)
if err:
raise ValueError("must be positive!")
# ValueError: "must be positive!"
```
Here `f_checked` is functionally pure, since there are no Exceptions involved and only ordinary values, so we know by construction that it'll already work with `jit`, pmap, pjit, scan, and all of JAX's transformations. The only problem is that this plumbing can be a pain!
checkify does this plumbing for you. When it sees a check application, checkify rewrites it to a functional form and plumbs the resulting error value through your function. That includes plumbing the error value through subsequent check calls (so that only the first is reported) and then ultimately returning it as an additional output of the checkify-transformed function:
```python
def f(x):
checkify.check(x > 0., "must be positive!") # convenient but effectful API
return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(0.)
err.throw()
# ValueError: must be positive! (check failed at <...>:2 (f))
```
We call this functionalizing or discharging the effect introduced by calling check. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but expose throw and get methods; see {mod}`jax.experimental.checkify`.
Checkify can also automatically add common checks
In addition to functionalizing check calls, checkify can add error checks automatically, e.g. checks for NaNs and out-of-bounds indexing. Consider these error cases:
```python
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # nan generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
```
By default checkify only discharges check applications, and won't do anything to catch errors like the above. But if you ask it to, checkify will also instrument your code with checks for these common issues automatically.
```python
def f(x, i):
y = x[i]
z = jnp.sin(y)
return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
These automatically-added checks include source line information in their error messages.
The API for selecting which automatic checks to enable is based on Sets. See {mod}`jax.experimental.checkify` for more details.
## Strengths and limitations of `jax.experimental.checkify`
<!-- TODO gotta finish this -->
### Strengths
* You can use it everywhere (errors are values and behave intuitively under transformations)
* Automatic instrumentation: you don't need to make local modifications to your code. Instead, `checkify` can instrument all of it!
### Limitations
* Adding a lot of runtime checks can be expensive
* Requires threading error values in and out of transformations

76
docs/debugging/flags.md Normal file
View File

@ -0,0 +1,76 @@
# JAX debugging flags
JAX offers flags and context managers.
## `jax_debug_nans` configuration option and context manager
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).
`jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN.
### Usage
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
* setting the `JAX_DEBUG_NANS=True` environment variable;
* adding `from jax.config import config` and `config.update("jax_debug_nans", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;
### Example(s)
```python
from jax.config import config
config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
```
#### Strengths and limitations of `jax_debug_nans`
##### Strengths
* Easy to apply
* Precisely detects where NaNs were produced
* Throws a standard Python exception and is compatible with PDB postmortem
##### Limitations
* Not compatible with `jax.pmap` or `jax.pjit`
* Re-running functions eagerly can be slow
* Errors on false positives (e.g. intentionally created NaNs)
## `jax_disable_jit` configuration option and context manager
**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`
`jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`).
### Usage
You can disable JIT-compilation by:
* setting the `JAX_DISABLE_JIT=True` environment variable;
* adding `from jax.config import config` and `config.update("jax_disable_jit", True)` near the top of your main file;
* adding from `jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_disable_jit=True`;
### Examples
```python
from jax.config import config
config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)
if jnp.isnan(y):
breakpoint()
return y
jax.jit(f)(-2.) # ==> Enters PDB breakpoint!
```
#### Strengths and limitations of `jax_disable_jit`
##### Strengths
* Easy to apply
* Enables use of Python's built-in `breakpoint` and `print`
* Throws standard Python exceptions and is compatible with PDB postmortem
##### Limitations
* Not compatible with `jax.pmap` or `jax.pjit`
* Running functions without JIT-compilation can be slow

91
docs/debugging/index.md Normal file
View File

@ -0,0 +1,91 @@
# Debugging in JAX
Do you have exploding gradients? Are nans making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools!
## [Interactive inspection with `jax.debug`](print_breakpoint)
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions
and use {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:
```python
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.breakpoint()
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
```
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
You can also use checkify to automatically add common checks:
```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
## [Throwing Python errors with JAX's debug flags](flags)
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
```python
from jax.config import config
config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
```
```{toctree}
:caption: Index
:maxdepth: 1
print_breakpoint
checkify_guide
flags
```

View File

@ -0,0 +1,293 @@
# `jax.debug.print` and `jax.debug.breakpoint`
The {mod}`jax.debug` package offers some useful tools for inspecting values
inside of JIT-ted functions.
## Debugging with `jax.debug.print` and other debugging callbacks
**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`- or `jax.pmap`-decorated functions:
```python
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
```
<!-- mattjj added this line -->
With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead!
Semantically, `jax.debug.print` is roughly equivalent to the following Python function
```python
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
```
except that it can be staged out and transformed by JAX. See the {func}`API reference <jax.debug.print>` for more details.
### Why "_debug_" print?
In the name of debugging, `jax.debug.print` can reveal information about _how_ computations are evaluated:
```python
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = np.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
```
Notice that the printed results are in different orders!
By revealing these inner-workings, the output of `jax.debug.print` doesn't respect JAX's usual semantics guarantees, like that `jax.vmap(f)(xs)` and `jax.lax.map(f, xs)` compute the same thing (in different ways). Yet these evaluation order details are exactly what we might want to see when debugging!
<!-- mattjj tweaked this line -->
So use `jax.debug.print` for debugging, and not when semantics guarantees are important.
### More examples of `jax.debug.print`
In addition to the above examples using `jit` and `vmap`, here are a few more to have in mind.
#### Printing under `jax.pmap`
When `jax.pmap`-ed, `jax.debug.print`s might be reordered!
```python
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# OR
# Prints: x: 1.0
# x: 0.0
```
#### Printing under `jax.grad`
Under a `jax.grad`, `jax.debug.print`s will only print on the forward pass:
```python
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
```
<!-- mattjj added this line -->
This behavior is similar to how Python's builtin `print` works under a `jax.grad`. But by using `jax.debug.print` here, the behavior is the same even if the caller applies a `jax.jit`.
To print on the backward pass, just use a `jax.custom_vjp`:
```python
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
```
### More control with `jax.debug.callback`
In fact, `jax.debug.print` is a thin convenience wrapper around `jax.debug.callback`, which can be used directly for greater control over string formatting, or even the kind of output.
Semantically, `jax.debug.callback` is roughly equivalent to the following Python function
```python
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
```
As with `jax.debug.print`, these callbacks should only be used for debugging output, like printing or plotting. Printing and plotting are pretty harmless, but if you use it for anything else its behavior might surprise you under transformations. For example, it's not safe to use `jax.debug.callback` for timing operations, since callbacks might reordered and asynchronous (see below).
### Sharp bits
Like most JAX APIs, `jax.debug.print` can cut you if you're not careful.
#### Ordering of printed results
When distinct calls to `jax.debug.print` involve arguments which don't depend on one another, they might be reordered when staged out, e.g. by `jax.jit`:
```python
@jax.jit
def f(x, y):
jax.debug.print("x: {}", x)
jax.debug.print("y: {}", y)
return x + y
f(2., 3.)
# Prints: x: 2.0
# y: 3.0
# OR
# Prints: y: 3.0
# x: 2.0
```
Why? Under the hood, the compiler gets a functional representation of the staged-out computation, where the imperative order of the Python function is lost and only data dependence remains. This change is invisible to users with functionally pure code, but in the presence of side-effects like printing, it's noticeable.
To preserve the original order of `jax.debug.print`s as written in your Python function, you can use `jax.debug.print(..., ordered=True)`, which will ensure the relative order of prints is preserved. But using `ordered=True` will raise an error under `jax.pmap` and other JAX transformations involving parallelism, since ordering can't be guaranteed under parallel execution.
#### Asynchronous callbacks
Depending on the backend, `jax.debug.print`s may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value.
```python
@jax.jit
def f(x):
jax.debug.print("x: {}")
return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.
```
To block on the `jax.debug.print`s in a function, you can call `jax.effects_barrier()`, which will wait until any remaining side-effects in the function have completed as well:
```python
@jax.jit
def f(x):
jax.debug.print("x: {}")
return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>
```
#### Performance impacts
##### Unnecessary materialization
While `jax.debug.print` was designed to have a minimal performance footprint, it can interfere with compiler optimizations and potentially affect the memory profile of your JAX programs.
```python
def f(w, b, x):
logits = w.dot(x) + b
jax.debug.print("logits: {}", logits)
return jax.nn.relu(logits)
```
In this example, we are printing intermediate values in between a linear layer and the activation function. Compilers like XLA can perform fusion optimizations, which might avoid materializing `logits` in memory. But when we use `jax.debug.print` on `logits`, we are forcing those intermediates to be materialized, potentially slowing down the program and increasing memory usage.
Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronization occurs that will materialize values on a single device.
##### Callback overhead
`jax.debug.print` inherently incurs communication between an accelerator and its host. The underlying mechanism differs from backend to backend (e.g. GPU vs TPU) but in all cases, we'll need to copy the printed values from device to host. In the CPU case, this overhead is smaller.
Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronization occurs that adds some overhead.
### Strengths and limitations of `jax.debug.print`
#### Strengths
* Print debugging is simple and intuitive
* `jax.debug.callback` can be used for other innocuous side-effects
#### Limitations
* Adding print statements is a manual process
* Can have performance impacts
## Interactive inspection with `jax.debug.breakpoint()`
**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
```python
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
```
![JAX debugger](../_static/debugger.gif)
`jax.debug.breakpoint()` is actually just an application of `jax.debug.callback(...)` that captures information about the call stack. It has the same transformation behaviors as `jax.debug.print` as a result (e.g. `vmap`-ing `jax.debug.breakpoint()` unrolls it across the mapped axis).
### Usage
Calling `jax.debug.breakpoint()` in a compiled JAX function will pause your program when it hits the breakpoint. You'll be presented with a `pdb`-like prompt that allows you to inspect the values in the call stack. Unlike `pdb`, you will not be able to step through the execution, but you are allowed to resume it.
Debugger commands:
* `help` - prints out available commands
* `p` - evaluates an expression and prints its result
* `pp` - evaluates an expression and pretty-prints its result
* `u(p)` - go up a stack frame
* `d(own)` - go down a stack frame
* `w(here)/bt` - print out a backtrace
* `l(ist)` - print out code context
* `c(ont(inue))` - resumes the execution of the program
* `q(uit)/exit` - exits the program (does not work on TPU)
### Examples
#### Usage with `jax.lax.cond`
When combined with `jax.lax.cond`, the debugger can become a useful tool for detecting `nan`s or `inf`s.
```python
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(has_nan, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
```
### Sharp bits
Because `jax.debug.breakpoint` is a just an application of `jax.debug.callback`, it has the same [sharp bits as `jax.debug.print`](#sharp-bits), with a few more caveats:
* `jax.debug.breakpoint` materializes *even more* intermediates than `jax.debug.print` because it forces materialization of all values in the call stack
* `jax.debug.breakpoint` has more runtime overhead than a `jax.debug.print` because it has to potentially copy all the intermediate values in a JAX program from device to host.
### Strengths and limitations of `jax.debug.breakpoint()`
#### Strengths
* Simple, intuitive and (somewhat) standard
* Can inspect many values at the same time, up and down the call stack
#### Limitations
* Need to potentially use many breakpoints pinpoint the source of an error
* Materializes many intermediates

View File

@ -19,6 +19,11 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
jax-101/index
.. toctree::
:maxdepth: 2
debugging/index
.. toctree::
:maxdepth: 1
:caption: Reference Documentation

19
docs/jax.debug.rst Normal file
View File

@ -0,0 +1,19 @@
jax.debug package
=================
.. currentmodule:: jax.debug
.. automodule:: jax.debug
Debugging utilities
--------------------------
:doc:`debugging/print_breakpoint` describes how to make use of JAX's debugging features.
.. autosummary::
:toctree: _autosummary
callback
print
breakpoint

View File

@ -0,0 +1,24 @@
jax.experimental.checkify module
=====================================
.. automodule:: jax.experimental.checkify
API
---
.. autosummary::
:toctree: _autosummary
checkify
check
check_error
Error
ErrorCategory
user_checks
nan_checks
index_checks
div_checks
float_checks
automatic_checks
all_checks

View File

@ -14,6 +14,7 @@ Experimental Modules
.. toctree::
:maxdepth: 1
jax.experimental.checkify
jax.experimental.global_device_array
jax.experimental.host_callback
jax.experimental.maps

View File

@ -12,6 +12,7 @@ Subpackages
jax.numpy
jax.scipy
jax.config
jax.debug
jax.dlpack
jax.distributed
jax.example_libraries

View File

@ -130,6 +130,7 @@ from jax._src.tree_util import (
from jax import abstract_arrays as abstract_arrays
from jax import api_util as api_util
from jax import distributed as distributed
from jax import debug as debug
from jax import dtypes as dtypes
from jax import errors as errors
from jax import image as image

View File

@ -137,8 +137,4 @@ def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined
with debug_lock:
debugger(frames, thread_id, **kwargs)
if ordered:
effect = debugging.DebugEffect.ORDERED_PRINT
else:
effect = debugging.DebugEffect.PRINT
debugging.debug_callback(_breakpoint_callback, effect, *flat_args)
debugging.debug_callback(_breakpoint_callback, *flat_args, ordered=ordered)

View File

@ -110,8 +110,8 @@ if jaxlib.version >= (0, 3, 15):
mlir.register_lowering(
debug_callback_p, debug_callback_lowering, platform="tpu")
def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args,
**kwargs):
def debug_callback(callback: Callable[..., Any], *args: Any,
ordered: bool = False, **kwargs: Any):
"""Calls a stageable Python callback.
`debug_callback` enables you to pass in a Python function that can be called
@ -128,33 +128,33 @@ def debug_callback(callback: Callable[..., Any], effect: DebugEffect, *args,
Args:
callback: A Python callable.
effect: A `DebugEffect`.
*args: The positional arguments to the callback.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this callback w.r.t.
other ordered callbacks.
**kwargs: The positional arguments to the callback.
Returns:
The value of `callback(*args, **kwargs)`.
"""
if not isinstance(effect, DebugEffect):
raise ValueError("Can only use `DebugEffect` effects in `debug_callback`")
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT
return debug_callback_p.bind(*flat_args, callback=callback, effect=effect,
in_tree=in_tree)
def _format_print_callback(fmt: str, *args, **kwargs):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
def debug_print(fmt: str, *args, ordered=False, **kwargs) -> None:
def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
"""Prints values and works in staged out JAX functions.
Args:
fmt: A format string, e.g. `"hello {x}"`, that will be used to format
fmt: A format string, e.g. ``"hello {x}"``, that will be used to format
input arguments.
*args: A list of positional arguments to be formatted.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this `debug_print` w.r.t.
other ordered `debug_print`s.
staged out computation will enforce ordering of this ``debug_print``
w.r.t. other ordered ``debug_print`` calls.
**kwargs: Additional keyword arguments to be formatted.
"""
effect = DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT
debug_callback(functools.partial(_format_print_callback, fmt), effect, *args,
**kwargs)
debug_callback(functools.partial(_format_print_callback, fmt), *args,
**kwargs, ordered=ordered)

17
jax/debug.py Normal file
View File

@ -0,0 +1,17 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.debugging import debug_callback as callback
from jax._src.debugging import debug_print as print
from jax._src.debugging import DebugEffect
from jax._src.debugger import breakpoint

View File

@ -153,6 +153,9 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
# TODO: Remove once tensorflow is 2.10.0 everywhere.
if p.name == "optimization_barrier":
continue
if p.name == "debug_callback":
# TODO(sharadmv,necula): enable debug callbacks in TF
continue
if p.name in tf_not_yet_impl:
self.assertNotIn(
p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl