mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

The atime file is only needed to implement the LRU eviction policy, which in turn is only needed if a maximum persistent compilation cache size is set. Writing this file can cause network filesystem performance problems and other issues, so only write it if users have opted in.
171 lines
4.8 KiB
Python
171 lines
4.8 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import tempfile
|
|
import time
|
|
|
|
from absl.testing import absltest
|
|
|
|
from jax._src import path as pathlib
|
|
from jax._src.lru_cache import _ATIME_SUFFIX, _CACHE_SUFFIX, LRUCache
|
|
import jax._src.test_util as jtu
|
|
|
|
|
|
class LRUCacheTestCase(jtu.JaxTestCase):
  """Base fixture that gives every test its own temporary cache directory.

  `name` holds the directory as a string and `path` holds the same location
  wrapped in `pathlib.Path`. Both are populated in `setUp` and reset to
  `None` in `tearDown`.
  """

  name: str | None
  path: pathlib.Path | None

  def setUp(self):
    # The test is skipped outright when the `filelock` package cannot be
    # found.
    filelock_spec = importlib.util.find_spec("filelock")
    if filelock_spec is None:
      self.skipTest("filelock is not installed")

    super().setUp()

    # Register the temporary directory with the test case's context stack.
    scratch_dir = tempfile.TemporaryDirectory()
    self.enter_context(scratch_dir)
    self.name, self.path = scratch_dir.name, pathlib.Path(scratch_dir.name)

  def tearDown(self):
    # Drop the references to the per-test directory before the base class
    # teardown runs.
    self.path = self.name = None
    super().tearDown()

  def assertCacheKeys(self, keys):
    # Assert that the cache files in the directory (those matching
    # `*{_CACHE_SUFFIX}`) are exactly the ones expected for `keys`.
    on_disk = set(self.path.glob(f"*{_CACHE_SUFFIX}"))
    wanted = {self.path / f"{key}{_CACHE_SUFFIX}" for key in keys}
    self.assertEqual(on_disk, wanted)
|
|
|
|
|
|
class LRUCacheTest(LRUCacheTestCase):
  """Exercises `LRUCache` put/get, eviction, and access-time file handling."""

  def test_get_nonexistent_key(self):
    # Looking up a key that was never stored yields `None`.
    lru = LRUCache(self.name, max_size=-1)
    self.assertIsNone(lru.get("a"))

  def test_put_and_get_key(self):
    lru = LRUCache(self.name, max_size=-1)

    # One entry: readable, and the only cache file present.
    lru.put("a", b"a")
    self.assertEqual(lru.get("a"), b"a")
    self.assertCacheKeys(("a",))

    # Second entry: both entries readable, both files present.
    lru.put("b", b"b")
    self.assertEqual(lru.get("a"), b"a")
    self.assertEqual(lru.get("b"), b"b")
    self.assertCacheKeys(("a", "b"))

  def test_put_empty_value(self):
    # A zero-length value round-trips as empty bytes.
    lru = LRUCache(self.name, max_size=-1)

    lru.put("a", b"")
    self.assertEqual(lru.get("a"), b"")

  def test_put_empty_key(self):
    # An empty key is rejected with a `ValueError`.
    lru = LRUCache(self.name, max_size=-1)

    with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
      lru.put("", b"a")

  def test_eviction(self):
    lru = LRUCache(self.name, max_size=2)

    lru.put("a", b"a")
    lru.put("b", b"b")

    # Sleep so that the timestamp recorded for `b` is strictly later than
    # the one recorded for `a`.
    time.sleep(1)
    lru.get("b")

    # Adding `c` must push out `a`.
    lru.put("c", b"c")
    self.assertCacheKeys(("b", "c"))

    # Touching `b` again leaves `c` as the least recently used entry.
    time.sleep(1)
    lru.get("b")

    # Adding `d` must push out `c`.
    lru.put("d", b"d")
    self.assertCacheKeys(("b", "d"))

  def test_eviction_with_empty_value(self):
    lru = LRUCache(self.name, max_size=1)

    lru.put("a", b"a")

    # `b` has length 0, so storing it must not trigger an eviction even
    # though the cache is already full.
    lru.put("b", b"")
    self.assertCacheKeys(("a", "b"))

    # Touching `a` leaves `b` as the least recently used entry.
    time.sleep(1)
    lru.get("a")

    # Storing `c` evicts the least recently used file (`b`) first; that
    # frees no room for `c`, so `a` has to be evicted too.
    lru.put("c", b"c")
    self.assertCacheKeys(("c",))

  def test_existing_cache_dir(self):
    original = LRUCache(self.name, max_size=2)

    original.put("a", b"a")

    # Discard the object and build a new one over the same directory, as
    # another process reinitializing the cache would.
    del original
    reopened = LRUCache(self.name, max_size=2)

    self.assertEqual(reopened.get("a"), b"a")

    # The LRU policy has to keep working after reinitialization.
    reopened.put("b", b"b")

    # Touching `a` leaves `b` as the least recently used entry.
    time.sleep(1)
    reopened.get("a")

    # Adding `c` must push out `b`.
    reopened.put("c", b"c")
    self.assertCacheKeys(("a", "c"))

  def test_max_size(self):
    lru = LRUCache(self.name, max_size=1)

    # A value larger than the whole cache triggers a warning and is not
    # stored.
    expected_warning = (
        r"Cache value for key .+? of size \d+ bytes exceeds the maximum "
        r"cache size of \d+ bytes"
    )
    with self.assertWarnsRegex(UserWarning, expected_warning):
      lru.put("a", b"aaaa")

    self.assertIsNone(lru.get("a"))
    self.assertCacheKeys(())

  # With the eviction policy disabled, no access-time file may be written.
  # Writing that file can be very slow and can cause problems on
  # large-scale network storage.
  def test_no_atime_file(self):
    lru = LRUCache(self.name, max_size=-1)
    atime_pattern = f"*{_ATIME_SUFFIX}"

    lru.put("a", b"a")
    self.assertEmpty(list(self.path.glob(atime_pattern)))

    lru.get("a")
    self.assertEmpty(list(self.path.glob(atime_pattern)))
|
|
|
|
|
|
# When executed as a script, run the tests through absl with the JAX test
# loader.
if __name__ == "__main__":
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)
|