Merge pull request #11636 from sharadmv:webpdb

PiperOrigin-RevId: 463692257
This commit is contained in:
jax authors 2022-07-27 15:21:35 -07:00
commit d5fdd9e266
4 changed files with 93 additions and 10 deletions

View File

@ -14,6 +14,8 @@
from jax._src.debugger.core import breakpoint
from jax._src.debugger import cli_debugger
from jax._src.debugger import colab_debugger
from jax._src.debugger import web_debugger
del cli_debugger # For registration only
del colab_debugger # For registration only
del web_debugger # For registration only

View File

@ -45,6 +45,7 @@ class CliDebugger(cmd.Cmd):
env = {}
curr_frame = self.frames[self.frame_index]
env.update(curr_frame.locals)
env.update(curr_frame.globals)
return eval(expr, {}, env)
def default(self, arg):

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import dataclasses
import functools
import inspect
import threading
@ -34,30 +35,31 @@ class DebuggerFrame:
"""Encapsulates Python frame information."""
filename: str
locals: Dict[str, Any]
globals: Dict[str, Any]
code_context: str
source: List[str]
lineno: int
offset: Optional[int]
def tree_flatten(self):
flat_locals, locals_tree = tree_util.tree_flatten(self.locals)
flat_vars, vars_tree = tree_util.tree_flatten((self.locals, self.globals))
is_valid = [
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
for l in flat_locals
for l in flat_vars
]
invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals)
return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename,
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
return valid_vars, (is_valid, invalid_vars, vars_tree, self.filename,
self.code_context, self.source, self.lineno,
self.offset)
@classmethod
def tree_unflatten(cls, info, valid_locals):
(is_valid, invalid_locals, locals_tree, filename, code_context, source,
def tree_unflatten(cls, info, valid_vars):
(is_valid, invalid_vars, vars_tree, filename, code_context, source,
lineno, offset) = info
flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals)
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals)
return DebuggerFrame(filename, locals_, code_context, source, lineno,
offset)
flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
locals_, globals_ = tree_util.tree_unflatten(vars_tree, flat_vars)
return DebuggerFrame(filename, locals_, globals_, code_context, source,
lineno, offset)
@classmethod
def from_frameinfo(cls, frame_info) -> DebuggerFrame:
@ -78,6 +80,7 @@ class DebuggerFrame:
return DebuggerFrame(
filename=frame_info.filename,
locals=frame_info.frame.f_locals,
globals=frame_info.frame.f_globals,
code_context=frame_info.code_context,
source=source,
lineno=frame_info.lineno,

View File

@ -0,0 +1,77 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import weakref
from typing import Any, Dict, List, Optional, Tuple
from jax._src.debugger import cli_debugger
from jax._src.debugger import core as debugger_core
try:
import web_pdb # pytype: disable=import-error
WEB_PDB_ENABLED = True
except:
WEB_PDB_ENABLED = False
_web_consoles: Dict[Tuple[str, int], web_pdb.WebConsole] = {}
class WebDebugger(cli_debugger.CliDebugger):
"""A web-based debugger."""
prompt = '(jaxdb) '
use_rawinput: bool = False
def __init__(self, frames: List[debugger_core.DebuggerFrame], thread_id,
completekey: str = "tab", host: str = "0.0.0.0", port: int = 5555):
if (host, port) not in _web_consoles:
_web_consoles[host, port] = web_pdb.WebConsole(host, port, self)
# Clobber the debugger in the web console
_web_console = _web_consoles[host, port]
_web_console._debugger = weakref.proxy(self)
super().__init__(frames, thread_id, stdin=_web_console, stdout=_web_console,
completekey=completekey)
def get_current_frame_data(self):
# Constructs the info needed for the web console to display info
current_frame = self.current_frame()
filename = current_frame.filename
lines = current_frame.source
locals = "\n".join([f"{key} = {value}" for key, value in
sorted(current_frame.locals.items())])
globals = "\n".join([f"{key} = {value}" for key, value in
sorted(current_frame.globals.items())])
current_line = None
if current_frame.offset is not None:
current_line = current_frame.offset + 1
return {
'dirname': os.path.dirname(os.path.abspath(filename)) + os.path.sep,
'filename': os.path.basename(filename),
'file_listing': '\n'.join(lines),
'current_line': current_line,
'breakpoints': [],
'globals': globals,
'locals': locals,
}
def run(self):
return self.cmdloop()
def run_debugger(frames: List[debugger_core.DebuggerFrame],
thread_id: Optional[int], **kwargs: Any):
WebDebugger(frames, thread_id, **kwargs).run()
if WEB_PDB_ENABLED:
debugger_core.register_debugger("web", run_debugger, 0)