mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Refactor debugger to have a registry
This commit is contained in:
parent
9e16efa98a
commit
7f8378e0db
@ -11,4 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax._src.debugger.cli_debugger import breakpoint
|
||||
from jax._src.debugger.core import breakpoint
|
||||
from jax._src.debugger import cli_debugger
|
||||
|
||||
del cli_debugger # For registration only
|
||||
|
@ -14,106 +14,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import cmd
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from typing import Any, Callable, Dict, IO, List, Optional
|
||||
from typing import Any, IO, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax._src import debugging
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
import jax.numpy as jnp
|
||||
from jax._src.debugger import core as debugger_core
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DebuggerFrame:
|
||||
"""Encapsulates Python frame information."""
|
||||
filename: str
|
||||
locals: Dict[str, Any]
|
||||
code_context: str
|
||||
source: List[str]
|
||||
lineno: int
|
||||
offset: Optional[int]
|
||||
DebuggerFrame = debugger_core.DebuggerFrame
|
||||
|
||||
def tree_flatten(self):
|
||||
flat_locals, locals_tree = tree_util.tree_flatten(self.locals)
|
||||
is_valid = [
|
||||
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
|
||||
for l in flat_locals
|
||||
]
|
||||
invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals)
|
||||
return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename,
|
||||
self.code_context, self.source, self.lineno,
|
||||
self.offset)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, info, valid_locals):
|
||||
(is_valid, invalid_locals, locals_tree, filename, code_context, source,
|
||||
lineno, offset) = info
|
||||
flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals)
|
||||
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals)
|
||||
return DebuggerFrame(filename, locals_, code_context, source, lineno,
|
||||
offset)
|
||||
|
||||
@classmethod
|
||||
def from_frameinfo(cls, frame_info) -> DebuggerFrame:
|
||||
try:
|
||||
_, start = inspect.getsourcelines(frame_info.frame)
|
||||
source = inspect.getsource(frame_info.frame).split('\n')
|
||||
offset = frame_info.lineno - start
|
||||
except OSError:
|
||||
source = []
|
||||
offset = None
|
||||
return DebuggerFrame(
|
||||
filename=frame_info.filename,
|
||||
locals=frame_info.frame.f_locals,
|
||||
code_context=frame_info.code_context,
|
||||
source=source,
|
||||
lineno=frame_info.lineno,
|
||||
offset=offset)
|
||||
|
||||
|
||||
debug_lock = threading.Lock()
|
||||
|
||||
|
||||
def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined-builtin
|
||||
"""Enters a breakpoint at a point in a program."""
|
||||
frame_infos = inspect.stack()
|
||||
# Filter out internal frames
|
||||
frame_infos = [
|
||||
frame_info for frame_info in frame_infos
|
||||
if traceback_util.include_frame(frame_info.frame)
|
||||
]
|
||||
frames = [
|
||||
DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos
|
||||
]
|
||||
# Throw out first frame corresponding to this function
|
||||
frames = frames[1:]
|
||||
flat_args, frames_tree = tree_util.tree_flatten(frames)
|
||||
|
||||
def _breakpoint_callback(*flat_args):
|
||||
frames = tree_util.tree_unflatten(frames_tree, flat_args)
|
||||
thread_id = None
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
thread_id = threading.get_ident()
|
||||
with debug_lock:
|
||||
TextDebugger(frames, thread_id, **kwargs).run()
|
||||
|
||||
if ordered:
|
||||
effect = debugging.DebugEffect.ORDERED_PRINT
|
||||
else:
|
||||
effect = debugging.DebugEffect.PRINT
|
||||
debugging.debug_callback(_breakpoint_callback, effect, *flat_args)
|
||||
|
||||
|
||||
class TextDebugger(cmd.Cmd):
|
||||
class CliDebugger(cmd.Cmd):
|
||||
"""A text-based debugger."""
|
||||
prompt = '(jaxdb) '
|
||||
use_rawinput: bool = False
|
||||
@ -200,3 +111,8 @@ class TextDebugger(cmd.Cmd):
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
self.stdout.write('--KeyboardInterrupt--\n')
|
||||
|
||||
def run_debugger(frames: List[DebuggerFrame], thread_id: Optional[int],
|
||||
**kwargs: Any):
|
||||
CliDebugger(frames, thread_id, **kwargs).run()
|
||||
debugger_core.register_debugger("cli", run_debugger, -1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user