mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax_check_tracer_leaks: add warning about debuggers
This commit is contained in:
parent
799df7e54a
commit
b472ac3c46
@ -469,7 +469,9 @@ check_tracer_leaks = config.define_bool_state(
|
||||
default=False,
|
||||
help=('Turn on checking for leaked tracers as soon as a trace completes. '
|
||||
'Enabling leak checking may have performance impacts: some caching '
|
||||
'is disabled, and other overheads may be added.'))
|
||||
'is disabled, and other overheads may be added. Additionally, be aware '
|
||||
'that some Python debuggers can cause false positives, so it is recommended '
|
||||
'to disable any debuggers while leak checking is enabled.'))
|
||||
checking_leaks = functools.partial(check_tracer_leaks, True)
|
||||
|
||||
debug_nans = config.define_bool_state(
|
||||
|
12
jax/core.py
12
jax/core.py
@ -26,6 +26,7 @@ import types
|
||||
from typing import (Any, Callable, ClassVar, DefaultDict, Dict, Generator,
|
||||
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
|
||||
Type, Union, cast, Iterable, Hashable)
|
||||
import warnings
|
||||
from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
@ -770,6 +771,15 @@ def reset_trace_state() -> bool:
|
||||
def cur_sublevel() -> Sublevel:
|
||||
return thread_local_state.trace_state.substack[-1]
|
||||
|
||||
TRACER_LEAK_DEBUGGER_WARNING = """\
|
||||
JAX check_tracer_leaks behavior can trigger false positives when used with a debugger.
|
||||
To avoid false positives and silence this warning, you can disable thread tracing using
|
||||
the following:
|
||||
|
||||
import threading
|
||||
threading.current_thread().pydev_do_not_trace = True
|
||||
"""
|
||||
|
||||
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
|
||||
"""Find the leaked tracers holding a reference to the MainTrace or SubLevel.
|
||||
|
||||
@ -777,6 +787,8 @@ def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]):
|
||||
reference to `x` inside of a lambda closure, and no tracers were leaked
|
||||
by the user. In this case an empty list is returned.
|
||||
"""
|
||||
if not getattr(threading.current_thread(), 'pydev_do_not_trace', True):
|
||||
warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
|
||||
traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x)))
|
||||
tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces)))
|
||||
return tracers
|
||||
|
Loading…
x
Reference in New Issue
Block a user