diff --git a/jax/_src/debugger/__init__.py b/jax/_src/debugger/__init__.py index 8e24e0b9d..019890143 100644 --- a/jax/_src/debugger/__init__.py +++ b/jax/_src/debugger/__init__.py @@ -11,4 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.debugger.cli_debugger import breakpoint +from jax._src.debugger.core import breakpoint +from jax._src.debugger import cli_debugger + +del cli_debugger # For registration only diff --git a/jax/_src/debugger/cli_debugger.py b/jax/_src/debugger/cli_debugger.py index e3f726427..da80840f3 100644 --- a/jax/_src/debugger/cli_debugger.py +++ b/jax/_src/debugger/cli_debugger.py @@ -14,106 +14,17 @@ from __future__ import annotations import cmd -import dataclasses -import inspect import sys -import threading import traceback -from typing import Any, Callable, Dict, IO, List, Optional +from typing import Any, IO, List, Optional -import numpy as np -from jax import core -from jax import tree_util -from jax._src import debugging -from jax._src import traceback_util -from jax._src import util -import jax.numpy as jnp +from jax._src.debugger import core as debugger_core -@tree_util.register_pytree_node_class -@dataclasses.dataclass(frozen=True) -class DebuggerFrame: - """Encapsulates Python frame information.""" - filename: str - locals: Dict[str, Any] - code_context: str - source: List[str] - lineno: int - offset: Optional[int] +DebuggerFrame = debugger_core.DebuggerFrame - def tree_flatten(self): - flat_locals, locals_tree = tree_util.tree_flatten(self.locals) - is_valid = [ - isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray)) - for l in flat_locals - ] - invalid_locals, valid_locals = util.partition_list(is_valid, flat_locals) - return valid_locals, (is_valid, invalid_locals, locals_tree, self.filename, - self.code_context, self.source, self.lineno, - self.offset) - - @classmethod - def tree_unflatten(cls, info, valid_locals): - (is_valid, invalid_locals, locals_tree, filename, code_context, source, - lineno, offset) = info - flat_locals = util.merge_lists(is_valid, invalid_locals, valid_locals) - locals_ = tree_util.tree_unflatten(locals_tree, flat_locals) - return DebuggerFrame(filename, locals_, code_context, source, lineno, - offset) - - @classmethod - def from_frameinfo(cls, frame_info) -> DebuggerFrame: - try: - _, start = inspect.getsourcelines(frame_info.frame) - source = inspect.getsource(frame_info.frame).split('\n') - offset = frame_info.lineno - start - except OSError: - source = [] - offset = None - return DebuggerFrame( - filename=frame_info.filename, - locals=frame_info.frame.f_locals, - code_context=frame_info.code_context, - source=source, - lineno=frame_info.lineno, - offset=offset) - - -debug_lock = threading.Lock() - - -def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined-builtin - """Enters a breakpoint at a point in a program.""" - frame_infos = inspect.stack() - # Filter out internal frames - frame_infos = [ - frame_info for frame_info in frame_infos - if traceback_util.include_frame(frame_info.frame) - ] - frames = [ - DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos - ] - # Throw out first frame corresponding to this function - frames = frames[1:] - flat_args, frames_tree = tree_util.tree_flatten(frames) - - def _breakpoint_callback(*flat_args): - frames = tree_util.tree_unflatten(frames_tree, flat_args) - thread_id = None - if threading.current_thread() is not threading.main_thread(): - thread_id = threading.get_ident() - with debug_lock: - TextDebugger(frames, thread_id, **kwargs).run() - - if ordered: - effect = debugging.DebugEffect.ORDERED_PRINT - else: - effect = debugging.DebugEffect.PRINT - debugging.debug_callback(_breakpoint_callback, effect, *flat_args) - - -class TextDebugger(cmd.Cmd): +class CliDebugger(cmd.Cmd): """A text-based debugger.""" prompt = '(jaxdb) ' use_rawinput: bool = False @@ -200,3 +111,8 @@ class TextDebugger(cmd.Cmd): break except KeyboardInterrupt: self.stdout.write('--KeyboardInterrupt--\n') + +def run_debugger(frames: List[DebuggerFrame], thread_id: Optional[int], + **kwargs: Any): + CliDebugger(frames, thread_id, **kwargs).run() +debugger_core.register_debugger("cli", run_debugger, -1)