Refactor debugger to have a registry

This commit is contained in:
Sharad Vikram 2022-07-12 17:44:37 -07:00
parent 9e16efa98a
commit 7f8378e0db
2 changed files with 13 additions and 94 deletions

View File

@ -11,4 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.debugger.cli_debugger import breakpoint
from jax._src.debugger.core import breakpoint
from jax._src.debugger import cli_debugger
del cli_debugger # For registration only

View File

@ -14,106 +14,17 @@
from __future__ import annotations
import cmd
import dataclasses
import inspect
import sys
import threading
import traceback
from typing import Any, Callable, Dict, IO, List, Optional
from typing import Any, IO, List, Optional
import numpy as np
from jax import core
from jax import tree_util
from jax._src import debugging
from jax._src import traceback_util
from jax._src import util
import jax.numpy as jnp
from jax._src.debugger import core as debugger_core
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class DebuggerFrame:
"""Encapsulates Python frame information."""
filename: str
locals: Dict[str, Any]
code_context: str
source: List[str]
lineno: int
offset: Optional[int]
DebuggerFrame = debugger_core.DebuggerFrame
def tree_flatten(self):
flat_locals, locals_tree = tree_util.tree_flatten(self.locals)
is_valid = [
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
for l in flat_locals
]
invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals)
return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename,
self.code_context, self.source, self.lineno,
self.offset)
@classmethod
def tree_unflatten(cls, info, valid_locals):
(is_valid, invalid_locals, locals_tree, filename, code_context, source,
lineno, offset) = info
flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals)
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals)
return DebuggerFrame(filename, locals_, code_context, source, lineno,
offset)
@classmethod
def from_frameinfo(cls, frame_info) -> DebuggerFrame:
try:
_, start = inspect.getsourcelines(frame_info.frame)
source = inspect.getsource(frame_info.frame).split('\n')
offset = frame_info.lineno - start
except OSError:
source = []
offset = None
return DebuggerFrame(
filename=frame_info.filename,
locals=frame_info.frame.f_locals,
code_context=frame_info.code_context,
source=source,
lineno=frame_info.lineno,
offset=offset)
debug_lock = threading.Lock()
def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program."""
frame_infos = inspect.stack()
# Filter out internal frames
frame_infos = [
frame_info for frame_info in frame_infos
if traceback_util.include_frame(frame_info.frame)
]
frames = [
DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos
]
# Throw out first frame corresponding to this function
frames = frames[1:]
flat_args, frames_tree = tree_util.tree_flatten(frames)
def _breakpoint_callback(*flat_args):
frames = tree_util.tree_unflatten(frames_tree, flat_args)
thread_id = None
if threading.current_thread() is not threading.main_thread():
thread_id = threading.get_ident()
with debug_lock:
TextDebugger(frames, thread_id, **kwargs).run()
if ordered:
effect = debugging.DebugEffect.ORDERED_PRINT
else:
effect = debugging.DebugEffect.PRINT
debugging.debug_callback(_breakpoint_callback, effect, *flat_args)
class TextDebugger(cmd.Cmd):
class CliDebugger(cmd.Cmd):
"""A text-based debugger."""
prompt = '(jaxdb) '
use_rawinput: bool = False
@ -200,3 +111,8 @@ class TextDebugger(cmd.Cmd):
break
except KeyboardInterrupt:
self.stdout.write('--KeyboardInterrupt--\n')
def run_debugger(frames: List[DebuggerFrame], thread_id: Optional[int],
**kwargs: Any):
CliDebugger(frames, thread_id, **kwargs).run()
debugger_core.register_debugger("cli", run_debugger, -1)