From 09fd173a3e24d9a12e81b9d893a982a62d37c071 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 12 Jul 2022 18:30:40 -0700 Subject: [PATCH] Add colab debugger --- jax/_src/debugger/__init__.py | 2 + jax/_src/debugger/cli_debugger.py | 1 - jax/_src/debugger/colab_debugger.py | 313 ++++++++++++++++++++++++++++ jax/_src/debugger/colab_lib.py | 168 +++++++++++++++ jax/_src/debugger/core.py | 134 ++++++++++++ 5 files changed, 617 insertions(+), 1 deletion(-) create mode 100644 jax/_src/debugger/colab_debugger.py create mode 100644 jax/_src/debugger/colab_lib.py create mode 100644 jax/_src/debugger/core.py diff --git a/jax/_src/debugger/__init__.py b/jax/_src/debugger/__init__.py index 019890143..30fceb54f 100644 --- a/jax/_src/debugger/__init__.py +++ b/jax/_src/debugger/__init__.py @@ -13,5 +13,7 @@ # limitations under the License. from jax._src.debugger.core import breakpoint from jax._src.debugger import cli_debugger +from jax._src.debugger import colab_debugger del cli_debugger # For registration only +del colab_debugger # For registration only diff --git a/jax/_src/debugger/cli_debugger.py b/jax/_src/debugger/cli_debugger.py index da80840f3..79b222cc6 100644 --- a/jax/_src/debugger/cli_debugger.py +++ b/jax/_src/debugger/cli_debugger.py @@ -21,7 +21,6 @@ from typing import Any, IO, List, Optional from jax._src.debugger import core as debugger_core - DebuggerFrame = debugger_core.DebuggerFrame class CliDebugger(cmd.Cmd): diff --git a/jax/_src/debugger/colab_debugger.py b/jax/_src/debugger/colab_debugger.py new file mode 100644 index 000000000..91632b2bf --- /dev/null +++ b/jax/_src/debugger/colab_debugger.py @@ -0,0 +1,313 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for Colab-specific debugger.""" +from __future__ import annotations + +import html +import inspect +import traceback + +from typing import List + +import uuid + +from jax._src.debugger import colab_lib +from jax._src.debugger import core as debugger_core +from jax._src.debugger import cli_debugger + +# pylint: disable=g-import-not-at-top +# pytype: disable=import-error +if colab_lib.IS_COLAB_ENABLED: + from google.colab import output +try: + import pygments + IS_PYGMENTS_ENABLED = True +except ImportError: + IS_PYGMENTS_ENABLED = False +# pytype: enable=import-error +# pylint: enable=g-import-not-at-top + + +class CodeViewer(colab_lib.DynamicDOMElement): + """A mutable DOM element that displays code as HTML.""" + + def __init__(self, code_: str, highlights: List[int], linenostart: int = 1): + self._code = code_ + self._highlights = highlights + self._view = colab_lib.dynamic(colab_lib.div()) + self._linenostart = linenostart + + def render(self): + self.update_code( + self._code, self._highlights, linenostart=self._linenostart) + + def clear(self): + self._view.clear() + + def append(self, child): + raise NotImplementedError + + def update(self, elem): + self._view.update(elem) + + def _highlight_code(self, code: str, highlights, linenostart: int): + is_dark_mode = output.eval_js( + 'document.documentElement.matches("[theme=dark]");') + code_style = "monokai" if is_dark_mode else "default" + hl_color = "#4e56b7" if is_dark_mode else "#fff7c1" + if IS_PYGMENTS_ENABLED: + lexer = pygments.lexers.get_lexer_by_name("python") + formatter = pygments.formatters.HtmlFormatter( + full=False, + hl_lines=highlights, + linenos=True, + linenostart=linenostart, + style=code_style) + if hl_color: + formatter.style.highlight_color = hl_color + css_ = formatter.get_style_defs() + code = pygments.highlight(code, lexer, formatter) + else: + return ""; + return code, css_ + + def update_code(self, code_, highlights, *, linenostart: int = 1): + """Updates the code viewer to use new code.""" + self._code = code_ + self._view.clear() + code_, css_ = self._highlight_code(self._code, highlights, linenostart) + uuid_ = uuid.uuid4() + code_div = colab_lib.div( + colab_lib.css(css_), + code_, + id=f"code-{uuid_}", + style=colab_lib.style({ + "max-height": "500px", + "overflow-y": "scroll", + "background-color": "var(--colab-border-color)", + "padding": "5px 5px 5px 5px", + })) + if highlights: + percent_scroll = highlights[0] / len(self._code.split("\n")) + else: + percent_scroll = 0. + self.update(code_div) + # Scroll to where the line is + output.eval_js(""" + console.log("{id}") + var elem = document.getElementById("{id}") + var maxScrollPosition = elem.scrollHeight - elem.clientHeight; + elem.scrollTop = maxScrollPosition * {percent_scroll} + """.format(id=f"code-{uuid_}", percent_scroll=percent_scroll)) + + +class FramePreview(colab_lib.DynamicDOMElement): + """Displays information about a stack frame.""" + + def __init__(self, frame): + super().__init__() + self._header = colab_lib.dynamic( + colab_lib.div(colab_lib.pre(colab_lib.code("")))) + self._code_view = CodeViewer("", highlights=[]) + self.frame = frame + self._file_cache = {} + + def clear(self): + self._header.clear() + self._code_view.clear() + + def append(self, child): + raise NotImplementedError + + def update(self, elem): + raise NotImplementedError + + def update_frame(self, frame): + """Updates the frame viewer to use a new frame.""" + self.frame = frame + lineno = self.frame.lineno or None + filename = self.frame.filename.strip() + if inspect.getmodulename(filename): + if filename not in self._file_cache: + try: + with open(filename, "r") as fp: + self._file_cache[filename] = fp.read() + source = self._file_cache[filename] + highlight = lineno + linenostart = 1 + except FileNotFoundError: + source = "\n".join(frame.source) + highlight = min(frame.offset + 1, len(frame.source) - 1) + linenostart = lineno - frame.offset + else: + source = "\n".join(frame.source) + highlight = min(frame.offset + 1, len(frame.source) - 1) + linenostart = lineno - frame.offset + self._header.clear() + self._header.update( + colab_lib.div( + colab_lib.pre(colab_lib.code(f"{html.escape(filename)}({lineno})")), + style=colab_lib.style({ + "padding": "5px 5px 5px 5px", + "background-color": "var(--colab-highlighted-surface-color)", + }))) + self._code_view.update_code(source, [highlight], linenostart=linenostart) + + def render(self): + self.update_frame(self.frame) + + +class DebuggerView(colab_lib.DynamicDOMElement): + """Main view for the Colab debugger.""" + + def __init__(self, frame, *, log_color=""): + super().__init__() + self._interaction_log = colab_lib.dynamic(colab_lib.div()) + self._frame_preview = FramePreview(frame) + self._header = colab_lib.dynamic( + colab_lib.div( + colab_lib.span("Breakpoint"), + style=colab_lib.style({ + "background-color": "var(--colab-secondary-surface-color)", + "color": "var(--colab-primary-text-color)", + "padding": "5px 5px 5px 5px", + "font-weight": "bold", + }))) + + def render(self): + self._header.render() + self._frame_preview.render() + self._interaction_log.render() + + def append(self, child): + raise NotImplementedError + + def update(self, elem): + raise NotImplementedError + + def clear(self): + self._header.clear() + self._interaction_log.clear() + self._frame_preview.clear() + + def update_frame(self, frame): + self._frame_preview.update_frame(frame) + + def log(self, text): + self._interaction_log.append(colab_lib.pre(text)) + + def read(self): + with output.use_tags(["stdin"]): + user_input = input() + output.clear(output_tags=["stdin"]) + return user_input + + +class ColabDebugger(cli_debugger.CliDebugger): + """A JAX debugger for a Colab environment.""" + + def __init__(self, + frames: List[debugger_core.DebuggerFrame], + thread_id: int): + super().__init__(frames, thread_id) + self._debugger_view = DebuggerView(self.current_frame()) + + def read(self): + return self._debugger_view.read() + + def cmdloop(self, intro=None): + self.preloop() + stop = None + while not stop: + if self.cmdqueue: + line = self.cmdqueue.pop(0) + else: + try: + line = self.read() + except EOFError: + line = "EOF" + line = self.precmd(line) + stop = self.onecmd(line) + stop = self.postcmd(stop, line) + self.postloop() + + def do_u(self, _): + if self.frame_index == len(self.frames) - 1: + self.log("At topmost frame.") + return False + self.frame_index += 1 + self._debugger_view.update_frame(self.current_frame()) + return False + + def do_d(self, _): + if self.frame_index == 0: + self.log("At bottommost frame.") + return False + self.frame_index -= 1 + self._debugger_view.update_frame(self.current_frame()) + return False + + def do_bt(self, _): + self.log("Traceback:") + for frame in self.frames[::-1]: + filename = frame.filename.strip() + filename = filename or "" + self.log(f" File: {filename}, line ({frame.lineno})") + if frame.offset < len(frame.source): + line = frame.source[frame.offset] + self.log(f" {line.strip()}") + else: + self.log(" ") + + def do_c(self, _): + return True + + def do_q(self, _): + return True + + def do_EOF(self, _): + return True + + def do_p(self, arg): + try: + value = self.evaluate(arg) + self.log(repr(value)) + except Exception: # pylint: disable=broad-except + self.log(traceback.format_exc(limit=1)) + return False + + do_pp = do_p + + def log(self, text): + self._debugger_view.log(html.escape(text)) + + def run(self): + self._debugger_view.render() + try: + self.cmdloop() + except KeyboardInterrupt: + self.log("--Keyboard-Interrupt--") + pass + self._debugger_view.clear() + + +def _run_debugger(frames, thread_id, **kwargs): + try: + ColabDebugger(frames, thread_id, **kwargs).run() + except Exception: + traceback.print_exc() + + +if colab_lib.IS_COLAB_ENABLED: + debugger_core.register_debugger("colab", _run_debugger, 1) diff --git a/jax/_src/debugger/colab_lib.py b/jax/_src/debugger/colab_lib.py new file mode 100644 index 000000000..97c7d53f3 --- /dev/null +++ b/jax/_src/debugger/colab_lib.py @@ -0,0 +1,168 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for building interfaces in Colab.""" +from __future__ import annotations + +import abc +import dataclasses +import functools +import uuid + +from typing import Any, Dict, List, Union + +# pylint: disable=g-import-not-at-top +# pytype: disable=import-error +try: + from google.colab import output + from IPython import display + IS_COLAB_ENABLED = True +except ImportError: + IS_COLAB_ENABLED = False +# pytype: enable=import-error +# pylint: enable=g-import-not-at-top + + +class DOMElement(metaclass=abc.ABCMeta): + + @abc.abstractmethod + def render(self): + pass + + +Element = Union[DOMElement, str] + + +class DynamicDOMElement(DOMElement): + """A DOM element that can be mutated.""" + + @abc.abstractmethod + def render(self): + pass + + @abc.abstractmethod + def append(self, child: DOMElement): + pass + + @abc.abstractmethod + def update(self, elem: DOMElement): + pass + + @abc.abstractmethod + def clear(self): + pass + +@dataclasses.dataclass +class DynamicDiv(DynamicDOMElement): + """A `div` that can be edited.""" + _uuid: str = dataclasses.field(init=False) + _root_elem: DOMElement = dataclasses.field(init=False) + elem: Union[DOMElement, str] + + def __post_init__(self): + self._uuid = str(uuid.uuid4()) + self._rendered = False + self._root_elem = div(id=self.tag) + + @property + def tag(self): + return f"tag-{self._uuid}" + + def render(self): + if self._rendered: + raise ValueError("Can't call `render` twice.") + self._root_elem.render() + self._rendered = True + self.append(self.elem) + + def append(self, child: DOMElement): + if not self._rendered: + self.render() + with output.use_tags([self.tag]): + with output.redirect_to_element(f"#{self.tag}"): + child.render() + + def update(self, elem: DOMElement): + self.clear() + self.elem = elem + self.render() + + def clear(self): + output.clear(output_tags=[self.tag]) + self._rendered = False + + +@dataclasses.dataclass +class StaticDOMElement(DOMElement): + """An immutable DOM element.""" + _uuid: str = dataclasses.field(init=False) + name: str + children: List[Union[str, DOMElement]] + attrs: Dict[str, str] + + def html(self): + attr_str = "" + if self.attrs: + attr_str = " " + (" ".join( + [f"{key}=\"{value}\"" for key, value in self.attrs.items()])) + children = [] + children = "\n".join([str(c) for c in self.children]) + return f"<{self.name}{attr_str}>{children}" + + def render(self): + display.display(display.HTML(self.html())) + + def attr(self, key: str) -> str: + return self.attrs[key] + + def __str__(self): + return self.html() + + def __repr__(self): + return self.html() + + def append(self, child: DOMElement) -> DOMElement: + return dataclasses.replace(self, children=[*self.children, child]) + + def replace(self, **kwargs) -> DOMElement: + return dataclasses.replace(self, **kwargs) + + +def _style_dict_to_str(style_dict: Dict[str, Any]) -> str: + return " ".join([f"{k}: {v};" for k, v in style_dict.items()]) + + +def dynamic(elem: StaticDOMElement) -> DynamicDiv: + return DynamicDiv(elem) + + +def _make_elem(tag: str, *children: Element, **attrs) -> StaticDOMElement: + """Helper function for making DOM elements.""" + return StaticDOMElement(tag, list(children), attrs) + + +code = functools.partial(_make_elem, "code") +div = functools.partial(_make_elem, "div") +li = functools.partial(_make_elem, "li") +ol = functools.partial(_make_elem, "ol") +pre = functools.partial(_make_elem, "pre") +progress = functools.partial(_make_elem, "progress") +span = functools.partial(_make_elem, "span") + + +def css(text: str) -> StaticDOMElement: + return StaticDOMElement("style", [text], {}) + + +def style(*args, **kwargs): + return _style_dict_to_str(dict(*args, **kwargs)) diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py new file mode 100644 index 000000000..f7b220235 --- /dev/null +++ b/jax/_src/debugger/core.py @@ -0,0 +1,134 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +import inspect +import threading + +from typing import Any, Dict, List, Optional, Tuple +from typing_extensions import Protocol + +import jax.numpy as jnp +from jax import core +from jax import tree_util +from jax._src import debugging +from jax._src import traceback_util +from jax._src import util +import numpy as np + +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class DebuggerFrame: + """Encapsulates Python frame information.""" + filename: str + locals: Dict[str, Any] + code_context: str + source: List[str] + lineno: int + offset: Optional[int] + + def tree_flatten(self): + flat_locals, locals_tree = tree_util.tree_flatten(self.locals) + is_valid = [ + isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray)) + for l in flat_locals + ] + invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals) + return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename, + self.code_context, self.source, self.lineno, + self.offset) + + @classmethod + def tree_unflatten(cls, info, valid_locals): + (is_valid, invalid_locals, locals_tree, filename, code_context, source, + lineno, offset) = info + flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals) + locals_ = tree_util.tree_unflatten(locals_tree, flat_locals) + return DebuggerFrame(filename, locals_, code_context, source, lineno, + offset) + + @classmethod + def from_frameinfo(cls, frame_info) -> DebuggerFrame: + try: + _, start = inspect.getsourcelines(frame_info.frame) + source = inspect.getsource(frame_info.frame).split('\n') + offset = frame_info.lineno - start + except OSError: + source = [] + offset = None + return DebuggerFrame( + filename=frame_info.filename, + locals=frame_info.frame.f_locals, + code_context=frame_info.code_context, + source=source, + lineno=frame_info.lineno, + offset=offset) + + +class Debugger(Protocol): + + def __call__(self, frames: List[DebuggerFrame], thread_id: Optional[int], + **kwargs: Any) -> None: + ... +_debugger_registry: Dict[str, Tuple[int, Debugger]] = {} + + +def get_debugger() -> Debugger: + debuggers = sorted(_debugger_registry.values(), key=lambda x: -x[0]) + if not debuggers: + raise ValueError("No debuggers registered!") + return debuggers[0][1] + + +def register_debugger(name: str, debugger: Debugger, priority: int) -> None: + if name in _debugger_registry: + raise ValueError(f"Debugger with name \"{name}\" already registered.") + _debugger_registry[name] = (priority, debugger) + + +debug_lock = threading.Lock() + + +def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined-builtin + """Enters a breakpoint at a point in a program.""" + frame_infos = inspect.stack() + # Filter out internal frames + frame_infos = [ + frame_info for frame_info in frame_infos + if traceback_util.include_frame(frame_info.frame) + ] + frames = [ + DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos + ] + # Throw out first frame corresponding to this function + frames = frames[1:] + flat_args, frames_tree = tree_util.tree_flatten(frames) + + def _breakpoint_callback(*flat_args): + frames = tree_util.tree_unflatten(frames_tree, flat_args) + thread_id = None + if threading.current_thread() is not threading.main_thread(): + thread_id = threading.get_ident() + debugger = get_debugger() + # Lock here because this could be called from multiple threads at the same + # time. + with debug_lock: + debugger(frames, thread_id, **kwargs) + + if ordered: + effect = debugging.DebugEffect.ORDERED_PRINT + else: + effect = debugging.DebugEffect.PRINT + debugging.debug_callback(_breakpoint_callback, effect, *flat_args)