mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make leak checker errors explain why objects are alive
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com> Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
a08ced86f3
commit
6ebf44a681
69
jax/core.py
69
jax/core.py
@ -20,6 +20,7 @@ from dataclasses import dataclass
|
||||
import functools
|
||||
from functools import partial, partialmethod, total_ordering
|
||||
import gc
|
||||
import inspect
|
||||
import itertools as it
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
@ -888,7 +889,9 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
|
||||
"""
|
||||
if not getattr(threading.current_thread(), 'pydev_do_not_trace', True):
|
||||
warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
|
||||
# Trigger garbage collection to filter out cyclical dependency false positives
|
||||
# Trigger garbage collection to filter out unreachable objects that are alive
|
||||
# only due to cyclical dependencies. (We don't care about unreachable leaked
|
||||
# tracers since they can't interact with user code and cause a problem.)
|
||||
gc.collect()
|
||||
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
|
||||
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
|
||||
@ -896,9 +899,71 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
|
||||
|
||||
def leaked_tracer_error(name: str, t, tracers: List[Tracer]) -> Exception:
|
||||
assert tracers
|
||||
msgs = '\n\n'.join(f'{tracer}{tracer._origin_msg()}' for tracer in tracers)
|
||||
why = partial(_why_alive, {id(tracers)})
|
||||
msgs = '\n\n'.join(f'{tracers[i]}{tracers[i]._origin_msg()}{why(tracers[i])}'
|
||||
for i in range(len(tracers)))
|
||||
return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{msgs}\n')
|
||||
|
||||
def _why_alive(ignore_ids: Set[int], x: Any) -> str:
|
||||
parents = lambda x: [r for r in gc.get_referrers(x) if id(r) not in ignore_ids]
|
||||
child, lines, seen = x, [], set()
|
||||
while (id(child) not in seen and type(child) is not types.ModuleType
|
||||
and parents(child)):
|
||||
parent = parents(child)[0] # just pick one parent
|
||||
|
||||
# For namespaces (like modules and class instances) and closures, the
|
||||
# references may form a simple chain: e.g. instance refers to its own
|
||||
# __dict__ which refers to child, or function refers to its __closure__
|
||||
# which refers to cells which refer to child. In these cases, we can provide
|
||||
# a more intuitive description by collapsing the chain into a single
|
||||
# parent->child jump. We do that by setting `parent` here to be a
|
||||
# grandparent (or great-grandparent) of `child`, and then handling that case
|
||||
# in _why_alive_container_info. See example:
|
||||
# https://github.com/google/jax/pull/13022#discussion_r1008456599
|
||||
# To prevent this collapsing behavior, just comment out this code block.
|
||||
# TODO(mattjj): after Python 3.7 is unsupported, replace with types.CellType
|
||||
cell_type = type((lambda x: lambda: x)(10.28).__closure__[0]) # type: ignore
|
||||
if (isinstance(parent, dict) and
|
||||
getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
|
||||
parent = parents(parent)[0]
|
||||
elif type(parent) is cell_type:
|
||||
parent = parents(parents(parent)[0])[0]
|
||||
|
||||
line = f'<{type(child).__name__} {id(child)}> is referred to by '
|
||||
lines.append(line + _why_alive_container_info(parent, id(child)))
|
||||
seen.add(id(child))
|
||||
child = parent
|
||||
return '\n' + '\n'.join(lines) if lines else ''
|
||||
|
||||
def _why_alive_container_info(container, obj_id) -> str:
|
||||
name = f'<{type(container).__name__} {id(container)}>'
|
||||
if type(container) is types.ModuleType:
|
||||
name = getattr(container, '__name__', name)
|
||||
if type(container) is types.FunctionType:
|
||||
name_ = getattr(container, '__name__', '<no-name>')
|
||||
closure = inspect.getclosurevars(container)
|
||||
keys = [k for k, v in dict(closure.nonlocals, **closure.globals).items()
|
||||
if id(v) == obj_id]
|
||||
if len(keys) == 1: return f'{name} ({name_}) closed-over variable {keys[0]}'
|
||||
elif len(keys) > 1: return (f'{name} in closed-over variables ' +
|
||||
', '.join(map(repr, keys)))
|
||||
if hasattr(container, '__dict__'):
|
||||
keys = [k for k in vars(container) if id(vars(container)[k]) == obj_id]
|
||||
if len(keys) == 1: return f'{name}.{str(keys[0])}'
|
||||
elif len(keys) > 1: return f'{name} in vars ' + ', '.join(map(repr, keys))
|
||||
if isinstance(container, (list, tuple)):
|
||||
idxs = [i for i, x in enumerate(container) if id(x) == obj_id]
|
||||
if len(idxs) == 1: return f'{name}[{idxs[0]}]'
|
||||
else: return f'{name} at indices ' + ', '.join(map(str, idxs))
|
||||
if isinstance(container, dict):
|
||||
keys = [k for k in container if id(container[k]) == obj_id]
|
||||
if len(keys) == 1: return f'{name}[{repr(keys[0])}]'
|
||||
else: return f'{name} at keys ' + ', '.join(map(repr, keys))
|
||||
if isinstance(container, types.ModuleType):
|
||||
return f' named {container.__name__}'
|
||||
return name
|
||||
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: Type[Trace],
|
||||
dynamic: bool = False,
|
||||
|
@ -164,8 +164,8 @@ class BatchTracer(Tracer):
|
||||
def _origin_msg(self):
|
||||
if self.source_info is None:
|
||||
return ""
|
||||
return ("\nThis Tracer was created on line "
|
||||
f"{source_info_util.summarize(self.source_info)}")
|
||||
return (f"\nThis BatchTracer with object id {id(self)} was created on line:"
|
||||
f"\n {source_info_util.summarize(self.source_info)}")
|
||||
|
||||
def _contents(self):
|
||||
return [('val', self.val), ('batch_dim', self.batch_dim)]
|
||||
|
@ -1460,7 +1460,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
if not self._trace.main.jaxpr_stack: # type: ignore
|
||||
# If this Tracer has been leaked the jaxpr stack may no longer be
|
||||
# available. So we can't print as much origin information.
|
||||
return ("\nThis Tracer was created on line "
|
||||
return ("\nThis DynamicJaxprTracer was created on line "
|
||||
f"{source_info_util.summarize(self._line_info)}")
|
||||
else:
|
||||
invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self)
|
||||
|
@ -3554,6 +3554,41 @@ class APITest(jtu.JaxTestCase):
|
||||
with jax.check_tracer_leaks():
|
||||
jax.jit(apply_fn)(1.0) # don't crash
|
||||
|
||||
def test_leak_checker_reference_chain(self):
|
||||
class A:
|
||||
def __init__(self, dct):
|
||||
self.dct = dct
|
||||
|
||||
a = A({})
|
||||
x = jnp.arange(3)
|
||||
|
||||
def sketch(x):
|
||||
def foo():
|
||||
return x
|
||||
a.dct['hi'] = [foo]
|
||||
return x
|
||||
|
||||
# TODO(mattjj): full test msg below fails (harmlessly) on CI, investigate
|
||||
msg = (
|
||||
r"This BatchTracer with object id [0-9]+ was created on line:\n"
|
||||
r" .*\n"
|
||||
r"<BatchTracer [0-9]+> is referred to by"
|
||||
)
|
||||
|
||||
# msg = (
|
||||
# r"This BatchTracer with object id [0-9]+ was created on line:\n"
|
||||
# r" .*\n"
|
||||
# r"<BatchTracer [0-9]+> is referred to by <function [0-9]+> \(foo\) "
|
||||
# r"closed-over variable x\n"
|
||||
# r"<function [0-9]+> is referred to by <list [0-9]+>\[0\]\n"
|
||||
# r"<list [0-9]+> is referred to by <dict [0-9]+>\['hi'\]\n"
|
||||
# r"<dict [0-9]+> is referred to by <A [0-9]+>\.dct\n"
|
||||
# )
|
||||
|
||||
with jax.check_tracer_leaks():
|
||||
with self.assertRaisesRegex(Exception, msg):
|
||||
jax.vmap(sketch)(x)
|
||||
|
||||
def test_default_backend(self):
|
||||
first_local_device = api.local_devices()[0]
|
||||
self.assertEqual(first_local_device.platform, api.default_backend())
|
||||
|
Loading…
x
Reference in New Issue
Block a user