mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11806 from sharadmv:debugger-improvements
PiperOrigin-RevId: 466337260
This commit is contained in:
commit
870e8a2928
@ -80,7 +80,7 @@ class DebuggerFrame:
|
||||
return DebuggerFrame(
|
||||
filename=frame_info.filename,
|
||||
locals=frame_info.frame.f_locals,
|
||||
globals=frame_info.frame.f_globals,
|
||||
globals={},
|
||||
code_context=frame_info.code_context,
|
||||
source=source,
|
||||
lineno=frame_info.lineno,
|
||||
@ -113,19 +113,45 @@ def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
|
||||
debug_lock = threading.Lock()
|
||||
|
||||
|
||||
def breakpoint(*, ordered: bool = False, backend=None, **kwargs): # pylint: disable=redefined-builtin
|
||||
"""Enters a breakpoint at a point in a program."""
|
||||
def breakpoint(*, backend: Optional[str] = None, filter_frames: bool = True,
|
||||
num_frames: Optional[int] = None, ordered: bool = False,
|
||||
**kwargs): # pylint: disable=redefined-builtin
|
||||
"""Enters a breakpoint at a point in a program.
|
||||
|
||||
Args:
|
||||
backend: The debugger backend to use. By default, picks the highest priority
|
||||
debugger and in the absence of other registered debuggers, falls back to
|
||||
the CLI debugger.
|
||||
filter_frames: Whether or not to filter out JAX-internal stack frames from
|
||||
the traceback. Since some libraries, like Flax, also make user of JAX's
|
||||
stack frame filtering system, this option can also affect whether stack
|
||||
frames from libraries are filtered.
|
||||
num_frames: The number of frames above the current stack frame to make
|
||||
available for inspection in the interactive debugger.
|
||||
ordered: A keyword only argument used to indicate whether or not the
|
||||
staged out computation will enforce ordering of this ``debug_print``
|
||||
with respect to other ordered ``debug_print`` calls.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
frame_infos = inspect.stack()
|
||||
# Filter out internal frames
|
||||
frame_infos = [
|
||||
frame_info for frame_info in frame_infos
|
||||
if traceback_util.include_frame(frame_info.frame)
|
||||
]
|
||||
frames = [
|
||||
DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos
|
||||
]
|
||||
# Throw out first frame corresponding to this function
|
||||
frames = frames[1:]
|
||||
frame_infos = frame_infos[1:]
|
||||
if num_frames is not None:
|
||||
frame_infos = frame_infos[:num_frames]
|
||||
# Filter out internal frames
|
||||
if filter_frames:
|
||||
frames = [
|
||||
DebuggerFrame.from_frameinfo(frame_info)
|
||||
for frame_info in frame_infos
|
||||
if traceback_util.include_frame(frame_info.frame)
|
||||
]
|
||||
else:
|
||||
frames = [
|
||||
DebuggerFrame.from_frameinfo(frame_info)
|
||||
for frame_info in frame_infos
|
||||
]
|
||||
flat_args, frames_tree = tree_util.tree_flatten(frames)
|
||||
|
||||
def _breakpoint_callback(*flat_args):
|
||||
|
@ -59,6 +59,8 @@ disabled_backends = []
|
||||
if jaxlib.version < (0, 3, 15):
|
||||
disabled_backends.append("tpu")
|
||||
|
||||
foo = 2
|
||||
|
||||
class CliDebuggerTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
@ -321,8 +323,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
\(jdb\) """.format(re.escape(repr(arr))))
|
||||
g(jnp.arange(8, dtype=jnp.int32))
|
||||
jax.effects_barrier()
|
||||
print(stdout.getvalue())
|
||||
print(expected)
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
@ -344,10 +344,70 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
\(jdb\) """)
|
||||
f(2.)
|
||||
jax.effects_barrier()
|
||||
print(stdout.getvalue())
|
||||
print(expected)
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
def test_debugger_accesses_globals(self):
|
||||
stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])
|
||||
|
||||
@jax.jit
|
||||
def g():
|
||||
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
|
||||
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
\(jdb\) \*\*\* NameError: name 'foo' is not defined
|
||||
\(jdb\) """)
|
||||
g()
|
||||
jax.effects_barrier()
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
def test_can_limit_num_frames(self):
|
||||
stdin, stdout = make_fake_stdin_stdout(["u", "p x", "c"])
|
||||
|
||||
def g():
|
||||
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli",
|
||||
num_frames=2)
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
x = 2
|
||||
g()
|
||||
return x
|
||||
|
||||
_ = f()
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
\(jdb\) .*
|
||||
.*
|
||||
.*
|
||||
.*
|
||||
.*
|
||||
.*
|
||||
.*
|
||||
\(jdb\) 2
|
||||
\(jdb\) """)
|
||||
jax.effects_barrier()
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["u", "u", "c"])
|
||||
|
||||
def g2():
|
||||
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli",
|
||||
num_frames=2)
|
||||
|
||||
@jax.jit
|
||||
def f2():
|
||||
x = 2
|
||||
g2()
|
||||
return x
|
||||
|
||||
expected = ".*At topmost frame.*"
|
||||
_ = f2()
|
||||
jax.effects_barrier()
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user