Don't write atime file if JAX_COMPILATION_CACHE_MAX_SIZE == -1

The atime file is only needed to implement the LRU eviction policy,
which is only needed if a max persistent compilation cache size is
set. Writing this file can cause network filesystem performance and
other issues, so only write it if users have opted in.
This commit is contained in:
Skye Wanderman-Milne 2025-02-13 17:37:27 -08:00
parent 4b94665f4f
commit d5d43fc46e
3 changed files with 28 additions and 8 deletions

View File

@ -22,7 +22,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
* {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support
column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
@ -46,6 +46,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.
* Bug fixes
* The persistent compilation cache no longer writes an access time file if
`JAX_COMPILATION_CACHE_MAX_SIZE` is unset or set to -1, i.e. if the LRU
eviction policy isn't enabled. This should improve performance when using
the cache with large-scale network storage.
## jax 0.5.0 (Jan 17, 2025)
As of this release, JAX now uses

View File

@ -94,7 +94,6 @@ class LRUCache(CacheInterface):
raise ValueError("key cannot be empty")
cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)
@ -108,8 +107,10 @@ class LRUCache(CacheInterface):
val = cache_path.read_bytes()
timestamp = time.time_ns().to_bytes(8, "little")
atime_path.write_bytes(timestamp)
if self.eviction_enabled:
timestamp = time.time_ns().to_bytes(8, "little")
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
atime_path.write_bytes(timestamp)
return val
@ -138,7 +139,6 @@ class LRUCache(CacheInterface):
return
cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
if self.eviction_enabled:
self.lock.acquire(timeout=self.lock_timeout_secs)
@ -151,8 +151,10 @@ class LRUCache(CacheInterface):
cache_path.write_bytes(val)
timestamp = time.time_ns().to_bytes(8, "little")
atime_path.write_bytes(timestamp)
if self.eviction_enabled:
timestamp = time.time_ns().to_bytes(8, "little")
atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
atime_path.write_bytes(timestamp)
finally:
if self.eviction_enabled:

View File

@ -21,7 +21,7 @@ import time
from absl.testing import absltest
from jax._src import path as pathlib
from jax._src.lru_cache import _CACHE_SUFFIX, LRUCache
from jax._src.lru_cache import _ATIME_SUFFIX, _CACHE_SUFFIX, LRUCache
import jax._src.test_util as jtu
@ -153,6 +153,18 @@ class LRUCacheTest(LRUCacheTestCase):
self.assertIsNone(cache.get("a"))
self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), set())
# Check that we don't write the access time file when the eviction policy is
# disabled. Writing this file can be extremely slow and cause problems on
# large-scale network storage.
def test_no_atime_file(self):
cache = LRUCache(self.name, max_size=-1)
cache.put("a", b"a")
self.assertEmpty(list(self.path.glob(f"*{_ATIME_SUFFIX}")))
cache.get("a")
self.assertEmpty(list(self.path.glob(f"*{_ATIME_SUFFIX}")))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())