mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make sublevel weak-referable, and enable the leak checker on sublevels.
Reimplement Sublevel to not inherit from `int`. See docs on weakref: "CPython implementation detail: Other built-in types such as tuple and int do not support weak references even when subclassed."
This commit is contained in:
parent
2947f562d5
commit
d86dd24bf8
29
jax/core.py
29
jax/core.py
@ -688,7 +688,23 @@ class TraceStack:
|
||||
new.dynamic = self.dynamic
|
||||
return new
|
||||
|
||||
class Sublevel(int): pass
|
||||
|
||||
@total_ordering
|
||||
class Sublevel:
|
||||
|
||||
def __init__(self, level: int):
|
||||
self.level = level
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.level)
|
||||
|
||||
def __eq__(self, other):
|
||||
return type(other) is Sublevel and self.level == other.level
|
||||
|
||||
def __lt__(self, other):
|
||||
return type(other) is Sublevel and self.level < other.level
|
||||
|
||||
|
||||
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
|
||||
AxisName = Hashable
|
||||
|
||||
@ -793,12 +809,11 @@ def new_sublevel() -> Generator[None, None, None]:
|
||||
finally:
|
||||
thread_local_state.trace_state.substack.pop()
|
||||
|
||||
# TODO(mattjj): to check sublevel leaks, we need to make Sublevel weakref-able
|
||||
# if debug_state.check_leaks:
|
||||
# t = ref(sublevel)
|
||||
# del sublevel
|
||||
# if t() is not None:
|
||||
# raise Exception('Leaked sublevel {}'.format(t()))
|
||||
if debug_state.check_leaks:
|
||||
t = ref(sublevel)
|
||||
del sublevel
|
||||
if t() is not None:
|
||||
raise Exception(f'Leaked sublevel {t()}.')
|
||||
|
||||
def maybe_new_sublevel(trace):
|
||||
# dynamic traces run the WrappedFun, so we raise the sublevel for them
|
||||
|
@ -2329,6 +2329,25 @@ class APITest(jtu.JaxTestCase):
|
||||
lax.scan(to_scan, x, None, length=1)
|
||||
f(np.arange(5.)) # doesn't crash
|
||||
|
||||
def test_leak_checker_catches_a_sublevel_leak(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with core.checking_leaks():
|
||||
@jit
|
||||
def f(x):
|
||||
lst = []
|
||||
@jit
|
||||
def g(x):
|
||||
lst.append(x)
|
||||
return x
|
||||
|
||||
x = g(x)
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
|
||||
f(3)
|
||||
|
||||
def test_default_backend(self):
|
||||
first_local_device = api.local_devices()[0]
|
||||
self.assertEqual(first_local_device.platform, api.default_backend())
|
||||
|
Loading…
x
Reference in New Issue
Block a user