mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11636 from sharadmv:webpdb
PiperOrigin-RevId: 463692257
This commit is contained in:
commit
d5fdd9e266
@ -14,6 +14,8 @@
|
||||
from jax._src.debugger.core import breakpoint
|
||||
from jax._src.debugger import cli_debugger
|
||||
from jax._src.debugger import colab_debugger
|
||||
from jax._src.debugger import web_debugger
|
||||
|
||||
del cli_debugger # For registration only
|
||||
del colab_debugger # For registration only
|
||||
del web_debugger # For registration only
|
||||
|
@ -45,6 +45,7 @@ class CliDebugger(cmd.Cmd):
|
||||
env = {}
|
||||
curr_frame = self.frames[self.frame_index]
|
||||
env.update(curr_frame.locals)
|
||||
env.update(curr_frame.globals)
|
||||
return eval(expr, {}, env)
|
||||
|
||||
def default(self, arg):
|
||||
|
@ -14,6 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
@ -34,30 +35,31 @@ class DebuggerFrame:
|
||||
"""Encapsulates Python frame information."""
|
||||
filename: str
|
||||
locals: Dict[str, Any]
|
||||
globals: Dict[str, Any]
|
||||
code_context: str
|
||||
source: List[str]
|
||||
lineno: int
|
||||
offset: Optional[int]
|
||||
|
||||
def tree_flatten(self):
|
||||
flat_locals, locals_tree = tree_util.tree_flatten(self.locals)
|
||||
flat_vars, vars_tree = tree_util.tree_flatten((self.locals, self.globals))
|
||||
is_valid = [
|
||||
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
|
||||
for l in flat_locals
|
||||
for l in flat_vars
|
||||
]
|
||||
invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals)
|
||||
return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename,
|
||||
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
|
||||
return valid_vars, (is_valid, invalid_vars, vars_tree, self.filename,
|
||||
self.code_context, self.source, self.lineno,
|
||||
self.offset)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, info, valid_locals):
|
||||
(is_valid, invalid_locals, locals_tree, filename, code_context, source,
|
||||
def tree_unflatten(cls, info, valid_vars):
|
||||
(is_valid, invalid_vars, vars_tree, filename, code_context, source,
|
||||
lineno, offset) = info
|
||||
flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals)
|
||||
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals)
|
||||
return DebuggerFrame(filename, locals_, code_context, source, lineno,
|
||||
offset)
|
||||
flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
|
||||
locals_, globals_ = tree_util.tree_unflatten(vars_tree, flat_vars)
|
||||
return DebuggerFrame(filename, locals_, globals_, code_context, source,
|
||||
lineno, offset)
|
||||
|
||||
@classmethod
|
||||
def from_frameinfo(cls, frame_info) -> DebuggerFrame:
|
||||
@ -78,6 +80,7 @@ class DebuggerFrame:
|
||||
return DebuggerFrame(
|
||||
filename=frame_info.filename,
|
||||
locals=frame_info.frame.f_locals,
|
||||
globals=frame_info.frame.f_globals,
|
||||
code_context=frame_info.code_context,
|
||||
source=source,
|
||||
lineno=frame_info.lineno,
|
||||
|
77
jax/_src/debugger/web_debugger.py
Normal file
77
jax/_src/debugger/web_debugger.py
Normal file
@ -0,0 +1,77 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import weakref
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from jax._src.debugger import cli_debugger
|
||||
from jax._src.debugger import core as debugger_core
|
||||
|
||||
try:
|
||||
import web_pdb # pytype: disable=import-error
|
||||
WEB_PDB_ENABLED = True
|
||||
except:
|
||||
WEB_PDB_ENABLED = False
|
||||
|
||||
|
||||
_web_consoles: Dict[Tuple[str, int], web_pdb.WebConsole] = {}
|
||||
|
||||
class WebDebugger(cli_debugger.CliDebugger):
|
||||
"""A web-based debugger."""
|
||||
prompt = '(jaxdb) '
|
||||
use_rawinput: bool = False
|
||||
|
||||
def __init__(self, frames: List[debugger_core.DebuggerFrame], thread_id,
|
||||
completekey: str = "tab", host: str = "0.0.0.0", port: int = 5555):
|
||||
if (host, port) not in _web_consoles:
|
||||
_web_consoles[host, port] = web_pdb.WebConsole(host, port, self)
|
||||
# Clobber the debugger in the web console
|
||||
_web_console = _web_consoles[host, port]
|
||||
_web_console._debugger = weakref.proxy(self)
|
||||
super().__init__(frames, thread_id, stdin=_web_console, stdout=_web_console,
|
||||
completekey=completekey)
|
||||
|
||||
def get_current_frame_data(self):
|
||||
# Constructs the info needed for the web console to display info
|
||||
current_frame = self.current_frame()
|
||||
filename = current_frame.filename
|
||||
lines = current_frame.source
|
||||
locals = "\n".join([f"{key} = {value}" for key, value in
|
||||
sorted(current_frame.locals.items())])
|
||||
globals = "\n".join([f"{key} = {value}" for key, value in
|
||||
sorted(current_frame.globals.items())])
|
||||
current_line = None
|
||||
if current_frame.offset is not None:
|
||||
current_line = current_frame.offset + 1
|
||||
return {
|
||||
'dirname': os.path.dirname(os.path.abspath(filename)) + os.path.sep,
|
||||
'filename': os.path.basename(filename),
|
||||
'file_listing': '\n'.join(lines),
|
||||
'current_line': current_line,
|
||||
'breakpoints': [],
|
||||
'globals': globals,
|
||||
'locals': locals,
|
||||
}
|
||||
|
||||
def run(self):
|
||||
return self.cmdloop()
|
||||
|
||||
def run_debugger(frames: List[debugger_core.DebuggerFrame],
|
||||
thread_id: Optional[int], **kwargs: Any):
|
||||
WebDebugger(frames, thread_id, **kwargs).run()
|
||||
|
||||
if WEB_PDB_ENABLED:
|
||||
debugger_core.register_debugger("web", run_debugger, 0)
|
Loading…
x
Reference in New Issue
Block a user