Mirror of https://github.com/ROCm/jax.git
Merge pull request #10110 from pschuh:weakref-bug
PiperOrigin-RevId: 438887762
Commit 1c3edc811d
@@ -222,12 +222,28 @@ memoize = cache(max_size=None)
 CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
 
 def weakref_lru_cache(call: Callable, maxsize=2048):
+  """
+  Least recently used cache decorator with weakref support.
+
+  The cache will take a weakref to the first argument of the wrapped function
+  and strong refs to all subsequent operations. In all other respects it should
+  behave similar to `functools.lru_cache`.
+  """
   cache: Dict[Any, Any] = {}
   hits = misses = 0
   lock = threading.Lock()
 
   def remove_key(tctx, args, kwargs, weak_arg):
-    del cache[(weak_arg, tctx, args, kwargs)]
+    k = (weak_arg, tctx, args, kwargs)
+    try:
+      # This has a chance to race with the iteration in next(iter(cache)),
+      # but we cannot lock because GC can get triggered synchronously inside
+      # a critical section and will not relinquish control until the callback
+      # has finished. This would lead to a deadlock between this weakref
+      # cleanup function and any function below which locks.
+      del cache[k]
+    except KeyError:
+      pass
 
   def wrapped(weak_arg, *args, **kwargs):
     nonlocal hits, misses
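The comment in `remove_key` is the heart of the fix: a weakref callback can run synchronously, on the current thread, whenever the cyclic garbage collector fires, so it must never block on the same non-reentrant lock the rest of the cache uses. A minimal standalone sketch of that hazard (the `Node` and `cleanup` names are illustrative, not part of the change; the acquire timeout exists only so the sketch terminates instead of deadlocking):

import gc
import threading
import weakref

gc.disable()  # keep the collection point deterministic for the sketch
lock = threading.Lock()

def cleanup(ref):
  # Runs synchronously inside gc.collect() on the thread that triggered the
  # collection. If that thread already holds `lock`, a blocking acquire here
  # would deadlock, because threading.Lock is not reentrant.
  got = lock.acquire(timeout=0.1)
  print("callback got lock:", got)
  if got:
    lock.release()

class Node:
  def __init__(self):
    self.cycle = [self]  # reference cycle: only the cyclic GC can free it

node = Node()
ref = weakref.ref(node, cleanup)
del node  # object survives in the cycle until the GC runs

with lock:
  gc.collect()  # prints "callback got lock: False"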
@@ -250,8 +266,22 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
     result = call(weak_arg, *args, **kwargs)
     with lock:
       cache[k] = result
+      num_errors = 0
       while len(cache) > maxsize:
-        del cache[next(iter(cache))]
+        try:
+          del_k = next(iter(cache))
+        # This happens if a weakref callback happens between iter and
+        # next. Just ignore the error. WeakKeyDictionary handles this
+        # by deferring the deletes, but that has a chance at leaking,
+        # and this solution is easier.
+        except RuntimeError:
+          num_errors += 1
+          if num_errors > len(cache):
+            # This must be some other problem.
+            raise
+          else:
+            continue
+        del cache[del_k]
     return result
 
   def cache_info():
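The eviction loop has to tolerate the race described above: because `remove_key` may delete an entry without holding the lock, the dictionary can change size between `iter(cache)` and `next(...)`, which CPython reports as a `RuntimeError`. The new code retries instead of crashing, and only re-raises if it sees more consecutive errors than the cache has entries. A small standalone illustration of the error being handled (the dict contents here are made up for the example):

cache = {"a": 1, "b": 2, "c": 3}

it = iter(cache)
del cache["a"]  # stands in for a weakref callback firing between iter and next
try:
  next(it)
except RuntimeError as e:
  print(e)  # "dictionary changed size during iteration"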
@@ -18,6 +18,7 @@ from jax import linear_util as lu
 from jax._src import test_util as jtu
 
 from jax.config import config
+from jax._src.util import weakref_lru_cache
 config.parse_flags_with_absl()
 FLAGS = config.FLAGS
 
@@ -63,6 +64,21 @@ class UtilTest(jtu.JaxTestCase):
     self.assertEqual(dict(three=6, four=8), scaled_kwargs)
     self.assertEqual(2, out_thunk())
 
+  def test_weakref_lru_cache(self):
+    @weakref_lru_cache
+    def example_cached_fn(key):
+      return object()
+
+    class Key:
+      def __init__(self):
+        # Make a GC loop.
+        self.ref_loop = [self]
+
+    stable_keys = [Key() for _ in range(2049)]
+    for i in range(10000):
+      example_cached_fn(stable_keys[i % len(stable_keys)])
+      example_cached_fn(Key())
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
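The test stresses exactly the two races fixed above: it uses more stable keys (2049) than the default maxsize of 2048, so eviction runs constantly, and it mixes in throwaway `Key()` objects whose reference loops force cleanup through the cyclic GC rather than plain refcounting. For reference, a minimal usage sketch of the decorator, importable as in the test; the `Data` class and `expensive` name are illustrative only:

from jax._src.util import weakref_lru_cache

@weakref_lru_cache
def expensive(key):
  # Cached on a weak reference to `key`; otherwise behaves like an
  # LRU cache (default maxsize=2048).
  return object()

class Data:
  pass

d = Data()
assert expensive(d) is expensive(d)            # hit: same weakly-referenced key
assert expensive(d) is not expensive(Data())   # different key, different entry
del d  # once `d` is garbage collected, its cache entry is dropped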