Merge pull request #11670 from sharadmv:debugging-docs

PiperOrigin-RevId: 464145832
This commit is contained in:
jax authors 2022-07-29 13:26:56 -07:00
commit 6cffa720e7
6 changed files with 77 additions and 66 deletions

View File

@ -1,6 +1,6 @@
# JAX debugging flags
JAX offers flags and context managers.
JAX offers flags and context managers that enable catching errors more easily.
## `jax_debug_nans` configuration option and context manager

View File

@ -1,6 +1,6 @@
# Debugging in JAX
# Runtime value debugging in JAX
Do you have exploding gradients? Are nans making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools!
Do you have exploding gradients? Are nans making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has tl;dr summaries and you can click the "Read more" links at the bottom to learn more.
## [Interactive inspection with `jax.debug`](print_breakpoint)
@ -26,6 +26,8 @@ Do you have exploding gradients? Are nans making you gnash your teeth? Just want
# 🤯 0.9092974662780762 🤯
```
Click [here](print_breakpoint) to learn more!
## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
@ -67,6 +69,8 @@ Do you have exploding gradients? Are nans making you gnash your teeth? Just want
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
Click [here](checkify_guide) to learn more!
## [Throwing Python errors with JAX's debug flags](flags)
**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
@ -80,8 +84,10 @@ def f(x, y):
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
```
Click [here](flags) to learn more!
```{toctree}
:caption: Index
:caption: Read more
:maxdepth: 1
print_breakpoint

View File

@ -26,7 +26,7 @@ DebuggerFrame = debugger_core.DebuggerFrame
class CliDebugger(cmd.Cmd):
"""A text-based debugger."""
prompt = '(jaxdb) '
prompt = '(jdb) '
use_rawinput: bool = False
def __init__(self, frames: List[DebuggerFrame], thread_id,
@ -36,7 +36,7 @@ class CliDebugger(cmd.Cmd):
self.frames = frames
self.frame_index = 0
self.thread_id = thread_id
self.intro = 'Entering jaxdb:'
self.intro = 'Entering jdb:'
def current_frame(self):
return self.frames[self.frame_index]

View File

@ -95,7 +95,9 @@ class Debugger(Protocol):
_debugger_registry: Dict[str, Tuple[int, Debugger]] = {}
def get_debugger() -> Debugger:
def get_debugger(backend: Optional[str] = None) -> Debugger:
if backend is not None and backend in _debugger_registry:
return _debugger_registry[backend][1]
debuggers = sorted(_debugger_registry.values(), key=lambda x: -x[0])
if not debuggers:
raise ValueError("No debuggers registered!")
@ -111,7 +113,7 @@ def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
debug_lock = threading.Lock()
def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined-builtin
def breakpoint(*, ordered: bool = False, backend=None, **kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program."""
frame_infos = inspect.stack()
# Filter out internal frames
@ -131,7 +133,7 @@ def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined
thread_id = None
if threading.current_thread() is not threading.main_thread():
thread_id = threading.get_ident()
debugger = get_debugger()
debugger = get_debugger(backend=backend)
# Lock here because this could be called from multiple threads at the same
# time.
with debug_lock:

View File

@ -33,7 +33,7 @@ _web_consoles: Dict[Tuple[str, int], web_pdb.WebConsole] = {}
class WebDebugger(cli_debugger.CliDebugger):
"""A web-based debugger."""
prompt = '(jaxdb) '
prompt = '(jdb) '
use_rawinput: bool = False
def __init__(self, frames: List[debugger_core.DebuggerFrame], thread_id,

View File

@ -67,7 +67,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
with self.assertRaises(SystemExit):
f(2.)
@ -79,13 +79,13 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
f(2.)
jax.effects_barrier()
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) """)
Entering jdb:
(jdb) """)
self.assertEqual(stdout.getvalue(), expected)
@jtu.skip_on_devices(*disabled_backends)
@ -94,12 +94,12 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) DeviceArray(2., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) DeviceArray(2., dtype=float32)
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
@ -111,12 +111,12 @@ class CliDebuggerTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) array(2., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) array(2., dtype=float32)
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
@ -128,12 +128,12 @@ class CliDebuggerTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x + 1.
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) (array(2., dtype=float32), array(3., dtype=float32))
(jaxdb) """)
Entering jdb:
(jdb) (array(2., dtype=float32), array(3., dtype=float32))
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
@ -145,20 +145,20 @@ class CliDebuggerTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
f(2.)
jax.effects_barrier()
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) > .*debugger_test\.py\([0-9]+\)
Entering jdb:
\(jdb\) > .*debugger_test\.py\([0-9]+\)
@jax\.jit
def f\(x\):
y = jnp\.sin\(x\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
return y
.*
\(jaxdb\) """)
\(jdb\) """)
self.assertRegex(stdout.getvalue(), expected)
@jtu.skip_on_devices(*disabled_backends)
@ -168,11 +168,11 @@ class CliDebuggerTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:.*
\(jaxdb\) Traceback:.*
Entering jdb:.*
\(jdb\) Traceback:.*
""")
f(2.)
jax.effects_barrier()
@ -184,7 +184,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
@jax.jit
@ -192,27 +192,27 @@ class CliDebuggerTest(jtu.JaxTestCase):
y = f(x)
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) > .*debugger_test\.py\([0-9]+\)
Entering jdb:
\(jdb\) > .*debugger_test\.py\([0-9]+\)
def f\(x\):
y = jnp\.sin\(x\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
return y
.*
\(jaxdb\) > .*debugger_test\.py\([0-9]+\).*
\(jdb\) > .*debugger_test\.py\([0-9]+\).*
@jax\.jit
def g\(x\):
-> y = f\(x\)
return jnp\.exp\(y\)
.*
\(jaxdb\) array\(2\., dtype=float32\)
\(jaxdb\) > .*debugger_test\.py\([0-9]+\)
\(jdb\) array\(2\., dtype=float32\)
\(jdb\) > .*debugger_test\.py\([0-9]+\)
def f\(x\):
y = jnp\.sin\(x\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
return y
.*
\(jaxdb\) """)
\(jdb\) """)
g(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
@ -223,20 +223,22 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = x + 1.
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,
backend="cli")
return y
@jax.jit
def g(x):
y = f(x) * 2.
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,
backend="cli")
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) array(3., dtype=float32)
(jaxdb) Entering jaxdb:
(jaxdb) array(6., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) array(3., dtype=float32)
(jdb) Entering jdb:
(jdb) array(6., dtype=float32)
(jdb) """)
g(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
@ -251,7 +253,8 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = x + 1.
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered)
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered,
backend="cli")
return 2. * y
@jax.jit
@ -260,11 +263,11 @@ class CliDebuggerTest(jtu.JaxTestCase):
y = f(x)
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) array(1., dtype=float32)
(jaxdb) Entering jaxdb:
(jaxdb) array(2., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) array(1., dtype=float32)
(jdb) Entering jdb:
(jdb) array(2., dtype=float32)
(jdb) """)
g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
@ -277,7 +280,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
@jax.pmap
@ -285,11 +288,11 @@ class CliDebuggerTest(jtu.JaxTestCase):
y = f(x)
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) array\(.*, dtype=float32\)
\(jaxdb\) Entering jaxdb:
\(jaxdb\) array\(.*, dtype=float32\)
\(jaxdb\) """)
Entering jdb:
\(jdb\) array\(.*, dtype=float32\)
\(jdb\) Entering jdb:
\(jdb\) array\(.*, dtype=float32\)
\(jdb\) """)
g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
@ -302,7 +305,7 @@ class CliDebuggerTest(jtu.JaxTestCase):
def f(x):
y = x + 1
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
def g(x):
@ -313,9 +316,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
with maps.Mesh(np.array(jax.devices()), ["dev"]):
arr = (1 + np.arange(8)).astype(np.int32)
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) {}
\(jaxdb\) """.format(re.escape(repr(arr))))
Entering jdb:
\(jdb\) {}
\(jdb\) """.format(re.escape(repr(arr))))
g(jnp.arange(8, dtype=jnp.int32))
jax.effects_barrier()
print(stdout.getvalue())