Add tracers to LeakChecker error, and filter out false positives this way.

If we can't find any hanging tracers in the gc.get_referrers chain, is it
really a leak? Probably not!
This commit is contained in:
Lena Martens 2021-07-21 13:27:48 +01:00 committed by lenamartens
parent b6e25fa00c
commit 2190734637
2 changed files with 38 additions and 3 deletions

View File

@ -18,6 +18,7 @@ from operator import attrgetter
from contextlib import contextmanager
from collections import namedtuple
from functools import total_ordering
import gc
import itertools as it
from weakref import ref
import threading
@ -736,6 +737,17 @@ def reset_trace_state() -> bool:
def cur_sublevel() -> Sublevel:
return thread_local_state.trace_state.substack[-1]
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
It's possible there's none! eg. there's some cases where JAX itself holds a
reference to `x` inside of a lambda closure, and no tracers were leaked
by the user. In this case an empty list is returned.
"""
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
return tracers
@contextmanager
def new_main(trace_type: Type[Trace],
dynamic: bool = False,
@ -761,7 +773,9 @@ def new_main(trace_type: Type[Trace],
t = ref(main)
del main
if t() is not None:
raise Exception(f'Leaked trace {t()}')
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
@ -782,7 +796,9 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
t = ref(main)
del main
if t() is not None:
raise Exception('Leaked trace {}'.format(t()))
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
@contextmanager
def eval_context():
@ -802,7 +818,9 @@ def new_sublevel() -> Generator[None, None, None]:
t = ref(sublevel)
del sublevel
if t() is not None:
raise Exception(f'Leaked sublevel {t()}.')
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers:
raise Exception(f'Leaked sublevel {t()}. Leaked tracer(s): {leaked_tracers}.')
def full_lower(val):
if isinstance(val, Tracer):

View File

@ -2596,6 +2596,23 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
f(3)
def test_leak_checker_avoids_false_positive_custom_jvp(self):
# see https://github.com/google/jax/issues/5636
with jax.checking_leaks():
@api.custom_jvp
def t(y):
return y
def t_jvp(p, t):
pass
t.defjvp(t_jvp)
@jit
def s(y):
return t(y)
s(3) # doesn't crash
def test_default_backend(self):
first_local_device = api.local_devices()[0]
self.assertEqual(first_local_device.platform, api.default_backend())