From d57d6fcee503e02ef6acf45383cf1a88bec7aac9 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 12 Jul 2022 13:52:14 -0700 Subject: [PATCH] Add webpdb option --- jax/_src/debugger/__init__.py | 2 + jax/_src/debugger/cli_debugger.py | 1 + jax/_src/debugger/core.py | 23 +++++---- jax/_src/debugger/web_debugger.py | 77 +++++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 10 deletions(-) create mode 100644 jax/_src/debugger/web_debugger.py diff --git a/jax/_src/debugger/__init__.py b/jax/_src/debugger/__init__.py index 30fceb54f..e2e6b0474 100644 --- a/jax/_src/debugger/__init__.py +++ b/jax/_src/debugger/__init__.py @@ -14,6 +14,8 @@ from jax._src.debugger.core import breakpoint from jax._src.debugger import cli_debugger from jax._src.debugger import colab_debugger +from jax._src.debugger import web_debugger del cli_debugger # For registration only del colab_debugger # For registration only +del web_debugger # For registration only diff --git a/jax/_src/debugger/cli_debugger.py b/jax/_src/debugger/cli_debugger.py index 660556410..f2e013663 100644 --- a/jax/_src/debugger/cli_debugger.py +++ b/jax/_src/debugger/cli_debugger.py @@ -45,6 +45,7 @@ class CliDebugger(cmd.Cmd): env = {} curr_frame = self.frames[self.frame_index] env.update(curr_frame.locals) + env.update(curr_frame.globals) return eval(expr, {}, env) def default(self, arg): diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index ef085facc..d279ac49c 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -14,6 +14,7 @@ from __future__ import annotations import dataclasses +import functools import inspect import threading @@ -34,30 +35,31 @@ class DebuggerFrame: """Encapsulates Python frame information.""" filename: str locals: Dict[str, Any] + globals: Dict[str, Any] code_context: str source: List[str] lineno: int offset: Optional[int] def tree_flatten(self): - flat_locals, locals_tree = tree_util.tree_flatten(self.locals) + flat_vars, vars_tree = tree_util.tree_flatten((self.locals, self.globals)) is_valid = [ isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray)) - for l in flat_locals + for l in flat_vars ] - invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals) - return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename, + invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars) + return valid_vars, (is_valid, invalid_vars, vars_tree, self.filename, self.code_context, self.source, self.lineno, self.offset) @classmethod - def tree_unflatten(cls, info, valid_locals): - (is_valid, invalid_locals, locals_tree, filename, code_context, source, + def tree_unflatten(cls, info, valid_vars): + (is_valid, invalid_vars, vars_tree, filename, code_context, source, lineno, offset) = info - flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals) - locals_ = tree_util.tree_unflatten(locals_tree, flat_locals) - return DebuggerFrame(filename, locals_, code_context, source, lineno, - offset) + flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars) + locals_, globals_ = tree_util.tree_unflatten(vars_tree, flat_vars) + return DebuggerFrame(filename, locals_, globals_, code_context, source, + lineno, offset) @classmethod def from_frameinfo(cls, frame_info) -> DebuggerFrame: @@ -78,6 +80,7 @@ class DebuggerFrame: return DebuggerFrame( filename=frame_info.filename, locals=frame_info.frame.f_locals, + globals=frame_info.frame.f_globals, code_context=frame_info.code_context, source=source, lineno=frame_info.lineno, diff --git a/jax/_src/debugger/web_debugger.py b/jax/_src/debugger/web_debugger.py new file mode 100644 index 000000000..83aad1179 --- /dev/null +++ b/jax/_src/debugger/web_debugger.py @@ -0,0 +1,77 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +import os +import weakref + +from typing import Any, Dict, List, Optional, Tuple + +from jax._src.debugger import cli_debugger +from jax._src.debugger import core as debugger_core + +try: + import web_pdb # pytype: disable=import-error + WEB_PDB_ENABLED = True +except: + WEB_PDB_ENABLED = False + + +_web_consoles: Dict[Tuple[str, int], web_pdb.WebConsole] = {} + +class WebDebugger(cli_debugger.CliDebugger): + """A web-based debugger.""" + prompt = '(jaxdb) ' + use_rawinput: bool = False + + def __init__(self, frames: List[debugger_core.DebuggerFrame], thread_id, + completekey: str = "tab", host: str = "0.0.0.0", port: int = 5555): + if (host, port) not in _web_consoles: + _web_consoles[host, port] = web_pdb.WebConsole(host, port, self) + # Clobber the debugger in the web console + _web_console = _web_consoles[host, port] + _web_console._debugger = weakref.proxy(self) + super().__init__(frames, thread_id, stdin=_web_console, stdout=_web_console, + completekey=completekey) + + def get_current_frame_data(self): + # Constructs the info needed for the web console to display info + current_frame = self.current_frame() + filename = current_frame.filename + lines = current_frame.source + locals = "\n".join([f"{key} = {value}" for key, value in + sorted(current_frame.locals.items())]) + globals = "\n".join([f"{key} = {value}" for key, value in + sorted(current_frame.globals.items())]) + current_line = None + if current_frame.offset is not None: + current_line = current_frame.offset + 1 + return { + 'dirname': os.path.dirname(os.path.abspath(filename)) + os.path.sep, + 'filename': os.path.basename(filename), + 'file_listing': '\n'.join(lines), + 'current_line': current_line, + 'breakpoints': [], + 'globals': globals, + 'locals': locals, + } + + def run(self): + return self.cmdloop() + +def run_debugger(frames: List[debugger_core.DebuggerFrame], + thread_id: Optional[int], **kwargs: Any): + WebDebugger(frames, thread_id, **kwargs).run() + +if WEB_PDB_ENABLED: + debugger_core.register_debugger("web", run_debugger, 0)