Merge pull request #11808 from sharadmv:debugger-fix-flatten

PiperOrigin-RevId: 466448203
This commit is contained in:
jax authors 2022-08-09 13:17:44 -07:00
commit 6d9512aa39
2 changed files with 74 additions and 10 deletions

View File

@ -14,11 +14,10 @@
from __future__ import annotations
import dataclasses
import functools
import inspect
import threading
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Hashable, List, Optional, Tuple
from typing_extensions import Protocol
import jax.numpy as jnp
@ -29,6 +28,46 @@ from jax._src import traceback_util
from jax._src import util
import numpy as np
@tree_util.register_pytree_node_class
class _DictWrapper:
keys: list[Hashable]
values: list[Any]
def __init__(self, keys, values):
self._keys = keys
self._values = values
def to_dict(self):
return dict(zip(self._keys, self._values))
def tree_flatten(self):
return self._values, self._keys
@classmethod
def tree_unflatten(cls, keys, values):
return _DictWrapper(keys, values)
class _CantFlatten:
__repr__ = lambda _: "<cant_flatten>"
cant_flatten = _CantFlatten()
def _safe_flatten_dict(dct: dict[Any, Any]
) -> tuple[list[Any], tree_util.PyTreeDef]:
# We avoid comparison between keys by just using the original order
keys, values = [], []
for key, value in dct.items():
try:
tree_util.tree_leaves(value)
except:
# If flattening fails, we substitute a sentinel object.
value = cant_flatten
keys.append(key)
values.append(value)
return tree_util.tree_flatten(_DictWrapper(keys, values))
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class DebuggerFrame:
@ -42,22 +81,26 @@ class DebuggerFrame:
offset: Optional[int]
def tree_flatten(self):
flat_vars, vars_tree = tree_util.tree_flatten((self.locals, self.globals))
flat_locals, locals_tree = _safe_flatten_dict(self.locals)
flat_globals, globals_tree = _safe_flatten_dict(self.globals)
flat_vars = flat_locals + flat_globals
is_valid = [
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
for l in flat_vars
]
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
return valid_vars, (is_valid, invalid_vars, vars_tree, self.filename,
self.code_context, self.source, self.lineno,
self.offset)
return valid_vars, (is_valid, invalid_vars, locals_tree, globals_tree,
len(flat_locals), self.filename, self.code_context,
self.source, self.lineno, self.offset)
@classmethod
def tree_unflatten(cls, info, valid_vars):
(is_valid, invalid_vars, vars_tree, filename, code_context, source,
lineno, offset) = info
(is_valid, invalid_vars, locals_tree, globals_tree, num_locals, filename,
code_context, source, lineno, offset) = info
flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
locals_, globals_ = tree_util.tree_unflatten(vars_tree, flat_vars)
flat_locals, flat_globals = util.split_list(flat_vars, [num_locals])
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals).to_dict()
globals_ = tree_util.tree_unflatten(globals_tree, flat_globals).to_dict()
return DebuggerFrame(filename, locals_, globals_, code_context, source,
lineno, offset)

View File

@ -376,7 +376,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
x = 2
g()
return x
_ = f()
expected = _format_multiline(r"""
Entering jdb:
@ -409,5 +408,27 @@ class CliDebuggerTest(jtu.JaxTestCase):
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
def test_can_handle_dictionaries_with_unsortable_keys(self):
stdin, stdout = make_fake_stdin_stdout(["p x", "p weird_dict",
"p weirder_dict", "c"])
@jax.jit
def f():
weird_dict = {(lambda x: x): 2., (lambda x: x * 2): 3}
weirder_dict = {(lambda x: x): weird_dict}
x = 2.
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
del weirder_dict
return x
expected = _format_multiline(r"""
Entering jdb:
\(jdb\) 2.0
\(jdb\) <cant_flatten>
\(jdb\) <cant_flatten>
\(jdb\) """)
_ = f()
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())