mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add tracers to LeakChecker error, and filter out false positives this way.
If we can't find any hanging tracers in the gc.get_referrers chain, is it really a leak? Probably not!
This commit is contained in:
parent
b6e25fa00c
commit
2190734637
24
jax/core.py
24
jax/core.py
@ -18,6 +18,7 @@ from operator import attrgetter
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from functools import total_ordering
|
||||
import gc
|
||||
import itertools as it
|
||||
from weakref import ref
|
||||
import threading
|
||||
@ -736,6 +737,17 @@ def reset_trace_state() -> bool:
|
||||
def cur_sublevel() -> Sublevel:
|
||||
return thread_local_state.trace_state.substack[-1]
|
||||
|
||||
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
|
||||
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
|
||||
|
||||
It's possible there's none! eg. there's some cases where JAX itself holds a
|
||||
reference to `x` inside of a lambda closure, and no tracers were leaked
|
||||
by the user. In this case an empty list is returned.
|
||||
"""
|
||||
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
|
||||
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
|
||||
return tracers
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: Type[Trace],
|
||||
dynamic: bool = False,
|
||||
@ -761,7 +773,9 @@ def new_main(trace_type: Type[Trace],
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
raise Exception(f'Leaked trace {t()}')
|
||||
leaked_tracers = maybe_find_leaked_tracers(t())
|
||||
if leaked_tracers:
|
||||
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
|
||||
|
||||
@contextmanager
|
||||
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
||||
@ -782,7 +796,9 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
raise Exception('Leaked trace {}'.format(t()))
|
||||
leaked_tracers = maybe_find_leaked_tracers(t())
|
||||
if leaked_tracers:
|
||||
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
|
||||
|
||||
@contextmanager
|
||||
def eval_context():
|
||||
@ -802,7 +818,9 @@ def new_sublevel() -> Generator[None, None, None]:
|
||||
t = ref(sublevel)
|
||||
del sublevel
|
||||
if t() is not None:
|
||||
raise Exception(f'Leaked sublevel {t()}.')
|
||||
leaked_tracers = maybe_find_leaked_tracers(t())
|
||||
if leaked_tracers:
|
||||
raise Exception(f'Leaked sublevel {t()}. Leaked tracer(s): {leaked_tracers}.')
|
||||
|
||||
def full_lower(val):
|
||||
if isinstance(val, Tracer):
|
||||
|
@ -2596,6 +2596,23 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
|
||||
f(3)
|
||||
|
||||
def test_leak_checker_avoids_false_positive_custom_jvp(self):
|
||||
# see https://github.com/google/jax/issues/5636
|
||||
with jax.checking_leaks():
|
||||
@api.custom_jvp
|
||||
def t(y):
|
||||
return y
|
||||
|
||||
def t_jvp(p, t):
|
||||
pass
|
||||
|
||||
t.defjvp(t_jvp)
|
||||
|
||||
@jit
|
||||
def s(y):
|
||||
return t(y)
|
||||
s(3) # doesn't crash
|
||||
|
||||
def test_default_backend(self):
|
||||
first_local_device = api.local_devices()[0]
|
||||
self.assertEqual(first_local_device.platform, api.default_backend())
|
||||
|
Loading…
x
Reference in New Issue
Block a user