Add colab debugger

This commit is contained in:
Sharad Vikram 2022-07-12 18:30:40 -07:00
parent 7f8378e0db
commit 09fd173a3e
5 changed files with 617 additions and 1 deletions

View File

@ -13,5 +13,7 @@
# limitations under the License.
from jax._src.debugger.core import breakpoint
from jax._src.debugger import cli_debugger
from jax._src.debugger import colab_debugger
del cli_debugger # For registration only
del colab_debugger # For registration only

View File

@ -21,7 +21,6 @@ from typing import Any, IO, List, Optional
from jax._src.debugger import core as debugger_core
DebuggerFrame = debugger_core.DebuggerFrame
class CliDebugger(cmd.Cmd):

View File

@ -0,0 +1,313 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for Colab-specific debugger."""
from __future__ import annotations
import html
import inspect
import traceback
from typing import List
import uuid
from jax._src.debugger import colab_lib
from jax._src.debugger import core as debugger_core
from jax._src.debugger import cli_debugger
# pylint: disable=g-import-not-at-top
# pytype: disable=import-error
if colab_lib.IS_COLAB_ENABLED:
from google.colab import output
try:
import pygments
IS_PYGMENTS_ENABLED = True
except ImportError:
IS_PYGMENTS_ENABLED = False
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top
class CodeViewer(colab_lib.DynamicDOMElement):
"""A mutable DOM element that displays code as HTML."""
def __init__(self, code_: str, highlights: List[int], linenostart: int = 1):
self._code = code_
self._highlights = highlights
self._view = colab_lib.dynamic(colab_lib.div())
self._linenostart = linenostart
def render(self):
self.update_code(
self._code, self._highlights, linenostart=self._linenostart)
def clear(self):
self._view.clear()
def append(self, child):
raise NotImplementedError
def update(self, elem):
self._view.update(elem)
def _highlight_code(self, code: str, highlights, linenostart: int):
is_dark_mode = output.eval_js(
'document.documentElement.matches("[theme=dark]");')
code_style = "monokai" if is_dark_mode else "default"
hl_color = "#4e56b7" if is_dark_mode else "#fff7c1"
if IS_PYGMENTS_ENABLED:
lexer = pygments.lexers.get_lexer_by_name("python")
formatter = pygments.formatters.HtmlFormatter(
full=False,
hl_lines=highlights,
linenos=True,
linenostart=linenostart,
style=code_style)
if hl_color:
formatter.style.highlight_color = hl_color
css_ = formatter.get_style_defs()
code = pygments.highlight(code, lexer, formatter)
else:
return "";
return code, css_
def update_code(self, code_, highlights, *, linenostart: int = 1):
"""Updates the code viewer to use new code."""
self._code = code_
self._view.clear()
code_, css_ = self._highlight_code(self._code, highlights, linenostart)
uuid_ = uuid.uuid4()
code_div = colab_lib.div(
colab_lib.css(css_),
code_,
id=f"code-{uuid_}",
style=colab_lib.style({
"max-height": "500px",
"overflow-y": "scroll",
"background-color": "var(--colab-border-color)",
"padding": "5px 5px 5px 5px",
}))
if highlights:
percent_scroll = highlights[0] / len(self._code.split("\n"))
else:
percent_scroll = 0.
self.update(code_div)
# Scroll to where the line is
output.eval_js("""
console.log("{id}")
var elem = document.getElementById("{id}")
var maxScrollPosition = elem.scrollHeight - elem.clientHeight;
elem.scrollTop = maxScrollPosition * {percent_scroll}
""".format(id=f"code-{uuid_}", percent_scroll=percent_scroll))
class FramePreview(colab_lib.DynamicDOMElement):
"""Displays information about a stack frame."""
def __init__(self, frame):
super().__init__()
self._header = colab_lib.dynamic(
colab_lib.div(colab_lib.pre(colab_lib.code(""))))
self._code_view = CodeViewer("", highlights=[])
self.frame = frame
self._file_cache = {}
def clear(self):
self._header.clear()
self._code_view.clear()
def append(self, child):
raise NotImplementedError
def update(self, elem):
raise NotImplementedError
def update_frame(self, frame):
"""Updates the frame viewer to use a new frame."""
self.frame = frame
lineno = self.frame.lineno or None
filename = self.frame.filename.strip()
if inspect.getmodulename(filename):
if filename not in self._file_cache:
try:
with open(filename, "r") as fp:
self._file_cache[filename] = fp.read()
source = self._file_cache[filename]
highlight = lineno
linenostart = 1
except FileNotFoundError:
source = "\n".join(frame.source)
highlight = min(frame.offset + 1, len(frame.source) - 1)
linenostart = lineno - frame.offset
else:
source = "\n".join(frame.source)
highlight = min(frame.offset + 1, len(frame.source) - 1)
linenostart = lineno - frame.offset
self._header.clear()
self._header.update(
colab_lib.div(
colab_lib.pre(colab_lib.code(f"{html.escape(filename)}({lineno})")),
style=colab_lib.style({
"padding": "5px 5px 5px 5px",
"background-color": "var(--colab-highlighted-surface-color)",
})))
self._code_view.update_code(source, [highlight], linenostart=linenostart)
def render(self):
self.update_frame(self.frame)
class DebuggerView(colab_lib.DynamicDOMElement):
"""Main view for the Colab debugger."""
def __init__(self, frame, *, log_color=""):
super().__init__()
self._interaction_log = colab_lib.dynamic(colab_lib.div())
self._frame_preview = FramePreview(frame)
self._header = colab_lib.dynamic(
colab_lib.div(
colab_lib.span("Breakpoint"),
style=colab_lib.style({
"background-color": "var(--colab-secondary-surface-color)",
"color": "var(--colab-primary-text-color)",
"padding": "5px 5px 5px 5px",
"font-weight": "bold",
})))
def render(self):
self._header.render()
self._frame_preview.render()
self._interaction_log.render()
def append(self, child):
raise NotImplementedError
def update(self, elem):
raise NotImplementedError
def clear(self):
self._header.clear()
self._interaction_log.clear()
self._frame_preview.clear()
def update_frame(self, frame):
self._frame_preview.update_frame(frame)
def log(self, text):
self._interaction_log.append(colab_lib.pre(text))
def read(self):
with output.use_tags(["stdin"]):
user_input = input()
output.clear(output_tags=["stdin"])
return user_input
class ColabDebugger(cli_debugger.CliDebugger):
"""A JAX debugger for a Colab environment."""
def __init__(self,
frames: List[debugger_core.DebuggerFrame],
thread_id: int):
super().__init__(frames, thread_id)
self._debugger_view = DebuggerView(self.current_frame())
def read(self):
return self._debugger_view.read()
def cmdloop(self, intro=None):
self.preloop()
stop = None
while not stop:
if self.cmdqueue:
line = self.cmdqueue.pop(0)
else:
try:
line = self.read()
except EOFError:
line = "EOF"
line = self.precmd(line)
stop = self.onecmd(line)
stop = self.postcmd(stop, line)
self.postloop()
def do_u(self, _):
if self.frame_index == len(self.frames) - 1:
self.log("At topmost frame.")
return False
self.frame_index += 1
self._debugger_view.update_frame(self.current_frame())
return False
def do_d(self, _):
if self.frame_index == 0:
self.log("At bottommost frame.")
return False
self.frame_index -= 1
self._debugger_view.update_frame(self.current_frame())
return False
def do_bt(self, _):
self.log("Traceback:")
for frame in self.frames[::-1]:
filename = frame.filename.strip()
filename = filename or "<no filename>"
self.log(f" File: {filename}, line ({frame.lineno})")
if frame.offset < len(frame.source):
line = frame.source[frame.offset]
self.log(f" {line.strip()}")
else:
self.log(" ")
def do_c(self, _):
return True
def do_q(self, _):
return True
def do_EOF(self, _):
return True
def do_p(self, arg):
try:
value = self.evaluate(arg)
self.log(repr(value))
except Exception: # pylint: disable=broad-except
self.log(traceback.format_exc(limit=1))
return False
do_pp = do_p
def log(self, text):
self._debugger_view.log(html.escape(text))
def run(self):
self._debugger_view.render()
try:
self.cmdloop()
except KeyboardInterrupt:
self.log("--Keyboard-Interrupt--")
pass
self._debugger_view.clear()
def _run_debugger(frames, thread_id, **kwargs):
try:
ColabDebugger(frames, thread_id, **kwargs).run()
except Exception:
traceback.print_exc()
if colab_lib.IS_COLAB_ENABLED:
debugger_core.register_debugger("colab", _run_debugger, 1)

View File

@ -0,0 +1,168 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for building interfaces in Colab."""
from __future__ import annotations
import abc
import dataclasses
import functools
import uuid
from typing import Any, Dict, List, Union
# pylint: disable=g-import-not-at-top
# pytype: disable=import-error
try:
from google.colab import output
from IPython import display
IS_COLAB_ENABLED = True
except ImportError:
IS_COLAB_ENABLED = False
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top
class DOMElement(metaclass=abc.ABCMeta):
@abc.abstractmethod
def render(self):
pass
Element = Union[DOMElement, str]
class DynamicDOMElement(DOMElement):
"""A DOM element that can be mutated."""
@abc.abstractmethod
def render(self):
pass
@abc.abstractmethod
def append(self, child: DOMElement):
pass
@abc.abstractmethod
def update(self, elem: DOMElement):
pass
@abc.abstractmethod
def clear(self):
pass
@dataclasses.dataclass
class DynamicDiv(DynamicDOMElement):
"""A `div` that can be edited."""
_uuid: str = dataclasses.field(init=False)
_root_elem: DOMElement = dataclasses.field(init=False)
elem: Union[DOMElement, str]
def __post_init__(self):
self._uuid = str(uuid.uuid4())
self._rendered = False
self._root_elem = div(id=self.tag)
@property
def tag(self):
return f"tag-{self._uuid}"
def render(self):
if self._rendered:
raise ValueError("Can't call `render` twice.")
self._root_elem.render()
self._rendered = True
self.append(self.elem)
def append(self, child: DOMElement):
if not self._rendered:
self.render()
with output.use_tags([self.tag]):
with output.redirect_to_element(f"#{self.tag}"):
child.render()
def update(self, elem: DOMElement):
self.clear()
self.elem = elem
self.render()
def clear(self):
output.clear(output_tags=[self.tag])
self._rendered = False
@dataclasses.dataclass
class StaticDOMElement(DOMElement):
"""An immutable DOM element."""
_uuid: str = dataclasses.field(init=False)
name: str
children: List[Union[str, DOMElement]]
attrs: Dict[str, str]
def html(self):
attr_str = ""
if self.attrs:
attr_str = " " + (" ".join(
[f"{key}=\"{value}\"" for key, value in self.attrs.items()]))
children = []
children = "\n".join([str(c) for c in self.children])
return f"<{self.name}{attr_str}>{children}</{self.name}>"
def render(self):
display.display(display.HTML(self.html()))
def attr(self, key: str) -> str:
return self.attrs[key]
def __str__(self):
return self.html()
def __repr__(self):
return self.html()
def append(self, child: DOMElement) -> DOMElement:
return dataclasses.replace(self, children=[*self.children, child])
def replace(self, **kwargs) -> DOMElement:
return dataclasses.replace(self, **kwargs)
def _style_dict_to_str(style_dict: Dict[str, Any]) -> str:
return " ".join([f"{k}: {v};" for k, v in style_dict.items()])
def dynamic(elem: StaticDOMElement) -> DynamicDiv:
return DynamicDiv(elem)
def _make_elem(tag: str, *children: Element, **attrs) -> StaticDOMElement:
"""Helper function for making DOM elements."""
return StaticDOMElement(tag, list(children), attrs)
code = functools.partial(_make_elem, "code")
div = functools.partial(_make_elem, "div")
li = functools.partial(_make_elem, "li")
ol = functools.partial(_make_elem, "ol")
pre = functools.partial(_make_elem, "pre")
progress = functools.partial(_make_elem, "progress")
span = functools.partial(_make_elem, "span")
def css(text: str) -> StaticDOMElement:
return StaticDOMElement("style", [text], {})
def style(*args, **kwargs):
return _style_dict_to_str(dict(*args, **kwargs))

134
jax/_src/debugger/core.py Normal file
View File

@ -0,0 +1,134 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import inspect
import threading
from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import Protocol
import jax.numpy as jnp
from jax import core
from jax import tree_util
from jax._src import debugging
from jax._src import traceback_util
from jax._src import util
import numpy as np
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class DebuggerFrame:
"""Encapsulates Python frame information."""
filename: str
locals: Dict[str, Any]
code_context: str
source: List[str]
lineno: int
offset: Optional[int]
def tree_flatten(self):
flat_locals, locals_tree = tree_util.tree_flatten(self.locals)
is_valid = [
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
for l in flat_locals
]
invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals)
return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename,
self.code_context, self.source, self.lineno,
self.offset)
@classmethod
def tree_unflatten(cls, info, valid_locals):
(is_valid, invalid_locals, locals_tree, filename, code_context, source,
lineno, offset) = info
flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals)
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals)
return DebuggerFrame(filename, locals_, code_context, source, lineno,
offset)
@classmethod
def from_frameinfo(cls, frame_info) -> DebuggerFrame:
try:
_, start = inspect.getsourcelines(frame_info.frame)
source = inspect.getsource(frame_info.frame).split('\n')
offset = frame_info.lineno - start
except OSError:
source = []
offset = None
return DebuggerFrame(
filename=frame_info.filename,
locals=frame_info.frame.f_locals,
code_context=frame_info.code_context,
source=source,
lineno=frame_info.lineno,
offset=offset)
class Debugger(Protocol):
def __call__(self, frames: List[DebuggerFrame], thread_id: Optional[int],
**kwargs: Any) -> None:
...
_debugger_registry: Dict[str, Tuple[int, Debugger]] = {}
def get_debugger() -> Debugger:
debuggers = sorted(_debugger_registry.values(), key=lambda x: -x[0])
if not debuggers:
raise ValueError("No debuggers registered!")
return debuggers[0][1]
def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
if name in _debugger_registry:
raise ValueError(f"Debugger with name \"{name}\" already registered.")
_debugger_registry[name] = (priority, debugger)
debug_lock = threading.Lock()
def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program."""
frame_infos = inspect.stack()
# Filter out internal frames
frame_infos = [
frame_info for frame_info in frame_infos
if traceback_util.include_frame(frame_info.frame)
]
frames = [
DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos
]
# Throw out first frame corresponding to this function
frames = frames[1:]
flat_args, frames_tree = tree_util.tree_flatten(frames)
def _breakpoint_callback(*flat_args):
frames = tree_util.tree_unflatten(frames_tree, flat_args)
thread_id = None
if threading.current_thread() is not threading.main_thread():
thread_id = threading.get_ident()
debugger = get_debugger()
# Lock here because this could be called from multiple threads at the same
# time.
with debug_lock:
debugger(frames, thread_id, **kwargs)
if ordered:
effect = debugging.DebugEffect.ORDERED_PRINT
else:
effect = debugging.DebugEffect.PRINT
debugging.debug_callback(_breakpoint_callback, effect, *flat_args)