jax_check_tracer_leaks: add warning about debuggers

This commit is contained in:
Jake VanderPlas 2021-11-08 09:21:18 -08:00
parent 799df7e54a
commit b472ac3c46
2 changed files with 15 additions and 1 deletions

View File

@ -469,7 +469,9 @@ check_tracer_leaks = config.define_bool_state(
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'))
'is disabled, and other overheads may be added. Additionally, be aware '
'that some Python debuggers can cause false positives, so it is recommended '
'to disable any debuggers while leak checking is enabled.'))
checking_leaks = functools.partial(check_tracer_leaks, True)
debug_nans = config.define_bool_state(

View File

@ -26,6 +26,7 @@ import types
from typing import (Any, Callable, ClassVar, DefaultDict, Dict, Generator,
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
Type, Union, cast, Iterable, Hashable)
import warnings
from weakref import ref
import numpy as np
@ -770,6 +771,15 @@ def reset_trace_state() -> bool:
def cur_sublevel() -> Sublevel:
return thread_local_state.trace_state.substack[-1]
TRACER_LEAK_DEBUGGER_WARNING = """\
JAX check_tracer_leaks behavior can trigger false positives when used with a debugger.
To avoid false positives and silence this warning, you can disable thread tracing using
the following:
import threading
threading.current_thread().pydev_do_not_trace = True
"""
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
@ -777,6 +787,8 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
reference to `x` inside of a lambda closure, and no tracers were leaked
by the user. In this case an empty list is returned.
"""
if not getattr(threading.current_thread(), 'pydev_do_not_trace', True):
warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
return tracers