make leak checker errors explain why objects are alive

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2022-10-26 14:14:58 -07:00
parent a08ced86f3
commit 6ebf44a681
4 changed files with 105 additions and 5 deletions

View File

@ -20,6 +20,7 @@ from dataclasses import dataclass
import functools
from functools import partial, partialmethod, total_ordering
import gc
import inspect
import itertools as it
import operator
from operator import attrgetter
@ -888,7 +889,9 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
"""
if not getattr(threading.current_thread(), 'pydev_do_not_trace', True):
warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
# Trigger garbage collection to filter out cyclical dependency false positives
# Trigger garbage collection to filter out unreachable objects that are alive
# only due to cyclical dependencies. (We don't care about unreachable leaked
# tracers since they can't interact with user code and cause a problem.)
gc.collect()
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
@ -896,9 +899,71 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
def leaked_tracer_error(name: str, t, tracers: List[Tracer]) -> Exception:
assert tracers
msgs = '\n\n'.join(f'{tracer}{tracer._origin_msg()}' for tracer in tracers)
why = partial(_why_alive, {id(tracers)})
msgs = '\n\n'.join(f'{tracers[i]}{tracers[i]._origin_msg()}{why(tracers[i])}'
for i in range(len(tracers)))
return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{msgs}\n')
def _why_alive(ignore_ids: Set[int], x: Any) -> str:
parents = lambda x: [r for r in gc.get_referrers(x) if id(r) not in ignore_ids]
child, lines, seen = x, [], set()
while (id(child) not in seen and type(child) is not types.ModuleType
and parents(child)):
parent = parents(child)[0] # just pick one parent
# For namespaces (like modules and class instances) and closures, the
# references may form a simple chain: e.g. instance refers to its own
# __dict__ which refers to child, or function refers to its __closure__
# which refers to cells which refer to child. In these cases, we can provide
# a more intuitive description by collapsing the chain into a single
# parent->child jump. We do that by setting `parent` here to be a
# grandparent (or great-grandparent) of `child`, and then handling that case
# in _why_alive_container_info. See example:
# https://github.com/google/jax/pull/13022#discussion_r1008456599
# To prevent this collapsing behavior, just comment out this code block.
# TODO(mattjj): after Python 3.7 is unsupported, replace with types.CellType
cell_type = type((lambda x: lambda: x)(10.28).__closure__[0]) # type: ignore
if (isinstance(parent, dict) and
getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
parent = parents(parent)[0]
elif type(parent) is cell_type:
parent = parents(parents(parent)[0])[0]
line = f'<{type(child).__name__} {id(child)}> is referred to by '
lines.append(line + _why_alive_container_info(parent, id(child)))
seen.add(id(child))
child = parent
return '\n' + '\n'.join(lines) if lines else ''
def _why_alive_container_info(container, obj_id) -> str:
name = f'<{type(container).__name__} {id(container)}>'
if type(container) is types.ModuleType:
name = getattr(container, '__name__', name)
if type(container) is types.FunctionType:
name_ = getattr(container, '__name__', '<no-name>')
closure = inspect.getclosurevars(container)
keys = [k for k, v in dict(closure.nonlocals, **closure.globals).items()
if id(v) == obj_id]
if len(keys) == 1: return f'{name} ({name_}) closed-over variable {keys[0]}'
elif len(keys) > 1: return (f'{name} in closed-over variables ' +
', '.join(map(repr, keys)))
if hasattr(container, '__dict__'):
keys = [k for k in vars(container) if id(vars(container)[k]) == obj_id]
if len(keys) == 1: return f'{name}.{str(keys[0])}'
elif len(keys) > 1: return f'{name} in vars ' + ', '.join(map(repr, keys))
if isinstance(container, (list, tuple)):
idxs = [i for i, x in enumerate(container) if id(x) == obj_id]
if len(idxs) == 1: return f'{name}[{idxs[0]}]'
else: return f'{name} at indices ' + ', '.join(map(str, idxs))
if isinstance(container, dict):
keys = [k for k in container if id(container[k]) == obj_id]
if len(keys) == 1: return f'{name}[{repr(keys[0])}]'
else: return f'{name} at keys ' + ', '.join(map(repr, keys))
if isinstance(container, types.ModuleType):
return f' named {container.__name__}'
return name
@contextmanager
def new_main(trace_type: Type[Trace],
dynamic: bool = False,

View File

@ -164,8 +164,8 @@ class BatchTracer(Tracer):
def _origin_msg(self):
if self.source_info is None:
return ""
return ("\nThis Tracer was created on line "
f"{source_info_util.summarize(self.source_info)}")
return (f"\nThis BatchTracer with object id {id(self)} was created on line:"
f"\n {source_info_util.summarize(self.source_info)}")
def _contents(self):
return [('val', self.val), ('batch_dim', self.batch_dim)]

View File

@ -1460,7 +1460,7 @@ class DynamicJaxprTracer(core.Tracer):
if not self._trace.main.jaxpr_stack: # type: ignore
# If this Tracer has been leaked the jaxpr stack may no longer be
# available. So we can't print as much origin information.
return ("\nThis Tracer was created on line "
return ("\nThis DynamicJaxprTracer was created on line "
f"{source_info_util.summarize(self._line_info)}")
else:
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)

View File

@ -3554,6 +3554,41 @@ class APITest(jtu.JaxTestCase):
with jax.check_tracer_leaks():
jax.jit(apply_fn)(1.0) # don't crash
def test_leak_checker_reference_chain(self):
class A:
def __init__(self, dct):
self.dct = dct
a = A({})
x = jnp.arange(3)
def sketch(x):
def foo():
return x
a.dct['hi'] = [foo]
return x
# TODO(mattjj): full test msg below fails (harmlessly) on CI, investigate
msg = (
r"This BatchTracer with object id [0-9]+ was created on line:\n"
r" .*\n"
r"<BatchTracer [0-9]+> is referred to by"
)
# msg = (
# r"This BatchTracer with object id [0-9]+ was created on line:\n"
# r" .*\n"
# r"<BatchTracer [0-9]+> is referred to by <function [0-9]+> \(foo\) "
# r"closed-over variable x\n"
# r"<function [0-9]+> is referred to by <list [0-9]+>\[0\]\n"
# r"<list [0-9]+> is referred to by <dict [0-9]+>\['hi'\]\n"
# r"<dict [0-9]+> is referred to by <A [0-9]+>\.dct\n"
# )
with jax.check_tracer_leaks():
with self.assertRaisesRegex(Exception, msg):
jax.vmap(sketch)(x)
def test_default_backend(self):
first_local_device = api.local_devices()[0]
self.assertEqual(first_local_device.platform, api.default_backend())