Merge pull request #10110 from pschuh:weakref-bug

PiperOrigin-RevId: 438887762
This commit is contained in:
jax authors 2022-04-01 12:45:35 -07:00
commit 1c3edc811d
2 changed files with 48 additions and 2 deletions

View File

@ -222,12 +222,28 @@ memoize = cache(max_size=None)
CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
def weakref_lru_cache(call: Callable, maxsize=2048):
"""
Least recently used cache decorator with weakref support.
The cache will take a weakref to the first argument of the wrapped function
and strong refs to all subsequent operations. In all other respects it should
behave similar to `functools.lru_cache`.
"""
cache: Dict[Any, Any] = {}
hits = misses = 0
lock = threading.Lock()
def remove_key(tctx, args, kwargs, weak_arg):
del cache[(weak_arg, tctx, args, kwargs)]
k = (weak_arg, tctx, args, kwargs)
try:
# This has a chance to race with the iteration in next(iter(cache)),
# but we cannot lock because GC can get triggered synchronously inside
# a critical section and will not relinquish control until the callback
# has finished. This would lead to a deadlock between this weakref
# cleanup function and any function below which locks.
del cache[k]
except KeyError:
pass
def wrapped(weak_arg, *args, **kwargs):
nonlocal hits, misses
@ -250,8 +266,22 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
result = call(weak_arg, *args, **kwargs)
with lock:
cache[k] = result
num_errors = 0
while len(cache) > maxsize:
del cache[next(iter(cache))]
try:
del_k = next(iter(cache))
# This happens if a weakref callback happens between iter and
# next. Just ignore the error. WeakKeyDictionary handles this
# by deferring the deletes, but that has a chance at leaking,
# and this solution is easier.
except RuntimeError:
num_errors += 1
if num_errors > len(cache):
# This must be some other problem.
raise
else:
continue
del cache[del_k]
return result
def cache_info():

View File

@ -18,6 +18,7 @@ from jax import linear_util as lu
from jax._src import test_util as jtu
from jax.config import config
from jax._src.util import weakref_lru_cache
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@ -63,6 +64,21 @@ class UtilTest(jtu.JaxTestCase):
self.assertEqual(dict(three=6, four=8), scaled_kwargs)
self.assertEqual(2, out_thunk())
def test_weakref_lru_cache(self):
@weakref_lru_cache
def example_cached_fn(key):
return object()
class Key:
def __init__(self):
# Make a GC loop.
self.ref_loop = [self]
stable_keys = [Key() for _ in range(2049)]
for i in range(10000):
example_cached_fn(stable_keys[i % len(stable_keys)])
example_cached_fn(Key())
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())