mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11808 from sharadmv:debugger-fix-flatten
PiperOrigin-RevId: 466448203
This commit is contained in:
commit
6d9512aa39
@ -14,11 +14,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Hashable, List, Optional, Tuple
|
||||
from typing_extensions import Protocol
|
||||
|
||||
import jax.numpy as jnp
|
||||
@ -29,6 +28,46 @@ from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
import numpy as np
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class _DictWrapper:
|
||||
keys: list[Hashable]
|
||||
values: list[Any]
|
||||
|
||||
def __init__(self, keys, values):
|
||||
self._keys = keys
|
||||
self._values = values
|
||||
|
||||
def to_dict(self):
|
||||
return dict(zip(self._keys, self._values))
|
||||
|
||||
def tree_flatten(self):
|
||||
return self._values, self._keys
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, keys, values):
|
||||
return _DictWrapper(keys, values)
|
||||
|
||||
|
||||
class _CantFlatten:
|
||||
__repr__ = lambda _: "<cant_flatten>"
|
||||
cant_flatten = _CantFlatten()
|
||||
|
||||
def _safe_flatten_dict(dct: dict[Any, Any]
|
||||
) -> tuple[list[Any], tree_util.PyTreeDef]:
|
||||
# We avoid comparison between keys by just using the original order
|
||||
keys, values = [], []
|
||||
for key, value in dct.items():
|
||||
try:
|
||||
tree_util.tree_leaves(value)
|
||||
except:
|
||||
# If flattening fails, we substitute a sentinel object.
|
||||
value = cant_flatten
|
||||
keys.append(key)
|
||||
values.append(value)
|
||||
return tree_util.tree_flatten(_DictWrapper(keys, values))
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DebuggerFrame:
|
||||
@ -42,22 +81,26 @@ class DebuggerFrame:
|
||||
offset: Optional[int]
|
||||
|
||||
def tree_flatten(self):
|
||||
flat_vars, vars_tree = tree_util.tree_flatten((self.locals, self.globals))
|
||||
flat_locals, locals_tree = _safe_flatten_dict(self.locals)
|
||||
flat_globals, globals_tree = _safe_flatten_dict(self.globals)
|
||||
flat_vars = flat_locals + flat_globals
|
||||
is_valid = [
|
||||
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
|
||||
for l in flat_vars
|
||||
]
|
||||
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
|
||||
return valid_vars, (is_valid, invalid_vars, vars_tree, self.filename,
|
||||
self.code_context, self.source, self.lineno,
|
||||
self.offset)
|
||||
return valid_vars, (is_valid, invalid_vars, locals_tree, globals_tree,
|
||||
len(flat_locals), self.filename, self.code_context,
|
||||
self.source, self.lineno, self.offset)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, info, valid_vars):
|
||||
(is_valid, invalid_vars, vars_tree, filename, code_context, source,
|
||||
lineno, offset) = info
|
||||
(is_valid, invalid_vars, locals_tree, globals_tree, num_locals, filename,
|
||||
code_context, source, lineno, offset) = info
|
||||
flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
|
||||
locals_, globals_ = tree_util.tree_unflatten(vars_tree, flat_vars)
|
||||
flat_locals, flat_globals = util.split_list(flat_vars, [num_locals])
|
||||
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals).to_dict()
|
||||
globals_ = tree_util.tree_unflatten(globals_tree, flat_globals).to_dict()
|
||||
return DebuggerFrame(filename, locals_, globals_, code_context, source,
|
||||
lineno, offset)
|
||||
|
||||
|
@ -376,7 +376,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
x = 2
|
||||
g()
|
||||
return x
|
||||
|
||||
_ = f()
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
@ -409,5 +408,27 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
jax.effects_barrier()
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_can_handle_dictionaries_with_unsortable_keys(self):
|
||||
stdin, stdout = make_fake_stdin_stdout(["p x", "p weird_dict",
|
||||
"p weirder_dict", "c"])
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
weird_dict = {(lambda x: x): 2., (lambda x: x * 2): 3}
|
||||
weirder_dict = {(lambda x: x): weird_dict}
|
||||
x = 2.
|
||||
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
|
||||
del weirder_dict
|
||||
return x
|
||||
expected = _format_multiline(r"""
|
||||
Entering jdb:
|
||||
\(jdb\) 2.0
|
||||
\(jdb\) <cant_flatten>
|
||||
\(jdb\) <cant_flatten>
|
||||
\(jdb\) """)
|
||||
_ = f()
|
||||
jax.effects_barrier()
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user