Make sublevel weak-referable, and enable the leak checker on sublevels.

Reimplement Sublevel to not inherit from `int`.
See docs on weakref: "CPython implementation detail: Other built-in types
such as tuple and int do not support weak references even when subclassed."
This commit is contained in:
Lena Martens 2021-03-18 17:32:33 +00:00 committed by lenamartens
parent 2947f562d5
commit d86dd24bf8
2 changed files with 41 additions and 7 deletions

View File

@ -688,7 +688,23 @@ class TraceStack:
new.dynamic = self.dynamic
return new
class Sublevel(int): pass
@total_ordering
class Sublevel:
def __init__(self, level: int):
self.level = level
def __repr__(self):
return str(self.level)
def __eq__(self, other):
return type(other) is Sublevel and self.level == other.level
def __lt__(self, other):
return type(other) is Sublevel and self.level < other.level
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
AxisName = Hashable
@ -793,12 +809,11 @@ def new_sublevel() -> Generator[None, None, None]:
finally:
thread_local_state.trace_state.substack.pop()
# TODO(mattjj): to check sublevel leaks, we need to make Sublevel weakref-able
# if debug_state.check_leaks:
# t = ref(sublevel)
# del sublevel
# if t() is not None:
# raise Exception('Leaked sublevel {}'.format(t()))
if debug_state.check_leaks:
t = ref(sublevel)
del sublevel
if t() is not None:
raise Exception(f'Leaked sublevel {t()}.')
def maybe_new_sublevel(trace):
# dynamic traces run the WrappedFun, so we raise the sublevel for them

View File

@ -2329,6 +2329,25 @@ class APITest(jtu.JaxTestCase):
lax.scan(to_scan, x, None, length=1)
f(np.arange(5.)) # doesn't crash
def test_leak_checker_catches_a_sublevel_leak(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with core.checking_leaks():
@jit
def f(x):
lst = []
@jit
def g(x):
lst.append(x)
return x
x = g(x)
return x
with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
f(3)
def test_default_backend(self):
first_local_device = api.local_devices()[0]
self.assertEqual(first_local_device.platform, api.default_backend())