mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Storing the last access time of a cache entry in a separate file
This commit is contained in:
parent
2e0c100fef
commit
934142dff4
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
import heapq
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from jax._src.compilation_cache_interface import CacheInterface
|
||||
@ -31,6 +32,10 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_CACHE_SUFFIX = "-cache"
|
||||
_ATIME_SUFFIX = "-atime"
|
||||
|
||||
|
||||
class LRUCache(CacheInterface):
|
||||
"""Bounded cache with least-recently-used (LRU) eviction policy.
|
||||
|
||||
@ -87,19 +92,25 @@ class LRUCache(CacheInterface):
|
||||
if not key:
|
||||
raise ValueError("key cannot be empty")
|
||||
|
||||
file = self.path / key
|
||||
cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
|
||||
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
|
||||
|
||||
if self.eviction_enabled:
|
||||
self.lock.acquire(timeout=self.lock_timeout_secs)
|
||||
|
||||
try:
|
||||
if not file.exists():
|
||||
if not cache_path.exists():
|
||||
logger.debug(f"Cache miss for key: {key!r}")
|
||||
return None
|
||||
|
||||
logger.debug(f"Cache hit for key: {key!r}")
|
||||
file.touch() # update mtime
|
||||
return file.read_bytes()
|
||||
|
||||
val = cache_path.read_bytes()
|
||||
|
||||
timestamp = time.time_ns().to_bytes(8, "little")
|
||||
atime_path.write_bytes(timestamp)
|
||||
|
||||
return val
|
||||
|
||||
finally:
|
||||
if self.eviction_enabled:
|
||||
@ -125,17 +136,22 @@ class LRUCache(CacheInterface):
|
||||
warnings.warn(msg)
|
||||
return
|
||||
|
||||
file = self.path / key
|
||||
cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
|
||||
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
|
||||
|
||||
if self.eviction_enabled:
|
||||
self.lock.acquire(timeout=self.lock_timeout_secs)
|
||||
|
||||
try:
|
||||
if file.exists():
|
||||
if cache_path.exists():
|
||||
return
|
||||
|
||||
self._evict_if_needed(additional_size=len(val))
|
||||
file.write_bytes(val)
|
||||
|
||||
cache_path.write_bytes(val)
|
||||
|
||||
timestamp = time.time_ns().to_bytes(8, "little")
|
||||
atime_path.write_bytes(timestamp)
|
||||
|
||||
finally:
|
||||
if self.eviction_enabled:
|
||||
@ -153,26 +169,34 @@ class LRUCache(CacheInterface):
|
||||
if not self.eviction_enabled:
|
||||
return
|
||||
|
||||
# a priority queue, each element is a tuple `(file_mtime, file, file_size)`
|
||||
h: list[tuple[int, pathlib.Path, int]] = []
|
||||
# a priority queue, each element is a tuple `(file_atime, key, file_size)`
|
||||
h: list[tuple[int, str, int]] = []
|
||||
dir_size = 0
|
||||
for file in self.path.iterdir():
|
||||
if file.is_file() and file != self.lock_path:
|
||||
file_size = file.stat().st_size
|
||||
file_mtime = file.stat().st_mtime_ns
|
||||
for cache_path in self.path.glob(f"*{_CACHE_SUFFIX}"):
|
||||
file_size = cache_path.stat().st_size
|
||||
|
||||
dir_size += file_size
|
||||
heapq.heappush(h, (file_mtime, file, file_size))
|
||||
key = cache_path.name.removesuffix(_CACHE_SUFFIX)
|
||||
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
|
||||
file_atime = int.from_bytes(atime_path.read_bytes(), "little")
|
||||
|
||||
dir_size += file_size
|
||||
heapq.heappush(h, (file_atime, key, file_size))
|
||||
|
||||
target_size = self.max_size - additional_size
|
||||
# evict files until the directory size is less than or equal
|
||||
# to `target_size`
|
||||
while dir_size > target_size:
|
||||
file_mtime, file, file_size = heapq.heappop(h)
|
||||
msg = (f"Evicting cache file {file.name}: file size {file_size} bytes, "
|
||||
f"target cache size {target_size} bytes")
|
||||
logger.debug(msg)
|
||||
file.unlink()
|
||||
file_atime, key, file_size = heapq.heappop(h)
|
||||
|
||||
logger.debug("Evicting cache entry %r: file size %d bytes, "
|
||||
"target cache size %d bytes", key, file_size, target_size)
|
||||
|
||||
cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
|
||||
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
|
||||
|
||||
cache_path.unlink()
|
||||
atime_path.unlink()
|
||||
|
||||
dir_size -= file_size
|
||||
|
||||
# See comments in `jax.src.compilation_cache.get_file_cache()` for details.
|
||||
|
@ -21,7 +21,7 @@ import time
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax._src import path as pathlib
|
||||
from jax._src.lru_cache import LRUCache
|
||||
from jax._src.lru_cache import _CACHE_SUFFIX, LRUCache
|
||||
import jax._src.test_util as jtu
|
||||
|
||||
|
||||
@ -44,30 +44,33 @@ class LRUCacheTestCase(jtu.JaxTestCase):
|
||||
self.name = None
|
||||
super().tearDown()
|
||||
|
||||
def assertCacheKeys(self, keys):
|
||||
self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), {self.path / f"{key}{_CACHE_SUFFIX}" for key in keys})
|
||||
|
||||
|
||||
class LRUCacheTest(LRUCacheTestCase):
|
||||
|
||||
def test_get_nonexistent_key(self):
|
||||
cache = LRUCache(self.name, max_size=-1)
|
||||
self.assertIsNone(cache.get("cache-a"))
|
||||
self.assertIsNone(cache.get("a"))
|
||||
|
||||
def test_put_and_get_key(self):
|
||||
cache = LRUCache(self.name, max_size=-1)
|
||||
|
||||
cache.put("cache-a", b"a")
|
||||
self.assertEqual(cache.get("cache-a"), b"a")
|
||||
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a"})
|
||||
cache.put("a", b"a")
|
||||
self.assertEqual(cache.get("a"), b"a")
|
||||
self.assertCacheKeys(("a",))
|
||||
|
||||
cache.put("cache-b", b"b")
|
||||
self.assertEqual(cache.get("cache-a"), b"a")
|
||||
self.assertEqual(cache.get("cache-b"), b"b")
|
||||
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"})
|
||||
cache.put("b", b"b")
|
||||
self.assertEqual(cache.get("a"), b"a")
|
||||
self.assertEqual(cache.get("b"), b"b")
|
||||
self.assertCacheKeys(("a", "b"))
|
||||
|
||||
def test_put_empty_value(self):
|
||||
cache = LRUCache(self.name, max_size=-1)
|
||||
|
||||
cache.put("cache-a", b"")
|
||||
self.assertEqual(cache.get("cache-a"), b"")
|
||||
cache.put("a", b"")
|
||||
self.assertEqual(cache.get("a"), b"")
|
||||
|
||||
def test_put_empty_key(self):
|
||||
cache = LRUCache(self.name, max_size=-1)
|
||||
@ -78,67 +81,67 @@ class LRUCacheTest(LRUCacheTestCase):
|
||||
def test_eviction(self):
|
||||
cache = LRUCache(self.name, max_size=2)
|
||||
|
||||
cache.put("cache-a", b"a")
|
||||
cache.put("cache-b", b"b")
|
||||
cache.put("a", b"a")
|
||||
cache.put("b", b"b")
|
||||
|
||||
# `sleep()` is necessary to guarantee that `cache-b`"s timestamp is strictly greater than `cache-a`"s
|
||||
# `sleep()` is necessary to guarantee that `b`'s timestamp is strictly greater than `a`'s
|
||||
time.sleep(1)
|
||||
cache.get("cache-b")
|
||||
cache.get("b")
|
||||
|
||||
# write `cache-c`, evict `cache-a`
|
||||
cache.put("cache-c", b"c")
|
||||
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-c"})
|
||||
# write `c`. `a` should be evicted
|
||||
cache.put("c", b"c")
|
||||
self.assertCacheKeys(("b", "c"))
|
||||
|
||||
# calling `get()` on `cache-b` makes `cache-c` least recently used
|
||||
# calling `get()` on `b` makes `c` least recently used
|
||||
time.sleep(1)
|
||||
cache.get("cache-b")
|
||||
cache.get("b")
|
||||
|
||||
# write `cache-d`, evict `cache-c`
|
||||
cache.put("cache-d", b"d")
|
||||
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-b", self.path / "cache-d"})
|
||||
# write `d`. `c` should be evicted
|
||||
cache.put("d", b"d")
|
||||
self.assertCacheKeys(("b", "d"))
|
||||
|
||||
def test_eviction_with_empty_value(self):
|
||||
cache = LRUCache(self.name, max_size=1)
|
||||
|
||||
cache.put("cache-a", b"a")
|
||||
cache.put("a", b"a")
|
||||
|
||||
# write `cache-b` with length 0
|
||||
# write `b` with length 0
|
||||
# eviction should not happen even though the cache is full
|
||||
cache.put("cache-b", b"")
|
||||
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-b"})
|
||||
cache.put("b", b"")
|
||||
self.assertCacheKeys(("a", "b"))
|
||||
|
||||
# calling `get()` on `cache-a` makes `cache-b` least recently used
|
||||
# calling `get()` on `a` makes `b` least recently used
|
||||
time.sleep(1)
|
||||
cache.get("cache-a")
|
||||
cache.get("a")
|
||||
|
||||
# writing `cache-c` should result in evicting the
|
||||
# least recent used file (`cache-b`) first,
|
||||
# but this is not sufficient to make room for `cache-c`,
|
||||
# so `cache-a` should be evicted as well
|
||||
cache.put("cache-c", b"c")
|
||||
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-c"})
|
||||
# writing `c` should result in evicting the
|
||||
# least recent used file (`b`) first,
|
||||
# but this is not sufficient to make room for `c`,
|
||||
# so `a` should be evicted as well
|
||||
cache.put("c", b"c")
|
||||
self.assertCacheKeys(("c",))
|
||||
|
||||
def test_existing_cache_dir(self):
|
||||
cache = LRUCache(self.name, max_size=2)
|
||||
|
||||
cache.put("cache-a", b"a")
|
||||
cache.put("a", b"a")
|
||||
|
||||
# simulates reinitializing the cache in another process
|
||||
del cache
|
||||
cache = LRUCache(self.name, max_size=2)
|
||||
|
||||
self.assertEqual(cache.get("cache-a"), b"a")
|
||||
self.assertEqual(cache.get("a"), b"a")
|
||||
|
||||
# ensure that the LRU policy survives cache reinitialization
|
||||
cache.put("cache-b", b"b")
|
||||
cache.put("b", b"b")
|
||||
|
||||
# calling `get()` on `cache-a` makes `cache-b` least recently used
|
||||
# calling `get()` on `a` makes `b` least recently used
|
||||
time.sleep(1)
|
||||
cache.get("cache-a")
|
||||
cache.get("a")
|
||||
|
||||
# write `cache-c`, evict `cache-b`
|
||||
cache.put("cache-c", b"c")
|
||||
self.assertEqual(set(self.path.glob("cache-*")), {self.path / "cache-a", self.path / "cache-c"})
|
||||
# write `c`. `b` should be evicted
|
||||
cache.put("c", b"c")
|
||||
self.assertCacheKeys(("a", "c"))
|
||||
|
||||
def test_max_size(self):
|
||||
cache = LRUCache(self.name, max_size=1)
|
||||
@ -146,9 +149,9 @@ class LRUCacheTest(LRUCacheTestCase):
|
||||
msg = (r"Cache value for key .+? of size \d+ bytes exceeds the maximum "
|
||||
r"cache size of \d+ bytes")
|
||||
with self.assertWarnsRegex(UserWarning, msg):
|
||||
cache.put("cache-a", b"aaaa")
|
||||
self.assertIsNone(cache.get("cache-a"))
|
||||
self.assertEqual(set(self.path.glob("cache-*")), set())
|
||||
cache.put("a", b"aaaa")
|
||||
self.assertIsNone(cache.get("a"))
|
||||
self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), set())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user