Merge pull request #11806 from sharadmv:debugger-improvements

PiperOrigin-RevId: 466337260
This commit is contained in:
jax authors 2022-08-09 06:14:56 -07:00
commit 870e8a2928
2 changed files with 102 additions and 16 deletions

View File

@ -80,7 +80,7 @@ class DebuggerFrame:
return DebuggerFrame(
filename=frame_info.filename,
locals=frame_info.frame.f_locals,
globals=frame_info.frame.f_globals,
globals={},
code_context=frame_info.code_context,
source=source,
lineno=frame_info.lineno,
@ -113,19 +113,45 @@ def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
debug_lock = threading.Lock()
def breakpoint(*, ordered: bool = False, backend=None, **kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program."""
def breakpoint(*, backend: Optional[str] = None, filter_frames: bool = True,
num_frames: Optional[int] = None, ordered: bool = False,
**kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program.
Args:
backend: The debugger backend to use. By default, picks the highest priority
debugger and in the absence of other registered debuggers, falls back to
the CLI debugger.
filter_frames: Whether or not to filter out JAX-internal stack frames from
the traceback. Since some libraries, like Flax, also make user of JAX's
stack frame filtering system, this option can also affect whether stack
frames from libraries are filtered.
num_frames: The number of frames above the current stack frame to make
available for inspection in the interactive debugger.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this ``debug_print``
with respect to other ordered ``debug_print`` calls.
Returns:
None.
"""
frame_infos = inspect.stack()
# Filter out internal frames
frame_infos = [
frame_info for frame_info in frame_infos
if traceback_util.include_frame(frame_info.frame)
]
frames = [
DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos
]
# Throw out first frame corresponding to this function
frames = frames[1:]
frame_infos = frame_infos[1:]
if num_frames is not None:
frame_infos = frame_infos[:num_frames]
# Filter out internal frames
if filter_frames:
frames = [
DebuggerFrame.from_frameinfo(frame_info)
for frame_info in frame_infos
if traceback_util.include_frame(frame_info.frame)
]
else:
frames = [
DebuggerFrame.from_frameinfo(frame_info)
for frame_info in frame_infos
]
flat_args, frames_tree = tree_util.tree_flatten(frames)
def _breakpoint_callback(*flat_args):

View File

@ -59,6 +59,8 @@ disabled_backends = []
if jaxlib.version < (0, 3, 15):
disabled_backends.append("tpu")
foo = 2
class CliDebuggerTest(jtu.JaxTestCase):
@jtu.skip_on_devices(*disabled_backends)
@ -321,8 +323,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
\(jdb\) """.format(re.escape(repr(arr))))
g(jnp.arange(8, dtype=jnp.int32))
jax.effects_barrier()
print(stdout.getvalue())
print(expected)
self.assertRegex(stdout.getvalue(), expected)
@jtu.skip_on_devices(*disabled_backends)
@ -344,10 +344,70 @@ class CliDebuggerTest(jtu.JaxTestCase):
\(jdb\) """)
f(2.)
jax.effects_barrier()
print(stdout.getvalue())
print(expected)
self.assertRegex(stdout.getvalue(), expected)
@jtu.skip_on_devices(*disabled_backends)
def test_debugger_accesses_globals(self):
stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])
@jax.jit
def g():
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
expected = _format_multiline(r"""
Entering jdb:
\(jdb\) \*\*\* NameError: name 'foo' is not defined
\(jdb\) """)
g()
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
@jtu.skip_on_devices(*disabled_backends)
def test_can_limit_num_frames(self):
stdin, stdout = make_fake_stdin_stdout(["u", "p x", "c"])
def g():
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli",
num_frames=2)
@jax.jit
def f():
x = 2
g()
return x
_ = f()
expected = _format_multiline(r"""
Entering jdb:
\(jdb\) .*
.*
.*
.*
.*
.*
.*
\(jdb\) 2
\(jdb\) """)
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
stdin, stdout = make_fake_stdin_stdout(["u", "u", "c"])
def g2():
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli",
num_frames=2)
@jax.jit
def f2():
x = 2
g2()
return x
expected = ".*At topmost frame.*"
_ = f2()
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())