diff --git a/CHANGELOG.md b/CHANGELOG.md index 11f93483d..c8efeaacf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more details. * Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`, - {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`, + {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`, {func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`. * {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support column-pivoting on CPU and GPU. See {jax-issue}`#20282` and @@ -46,6 +46,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more detail. +* Bug fixes + * Persistent compilation cache no longer writes access time file if + JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU + eviction policy isn't enabled. This should improve performance when using + the cache with large-scale network storage. + ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 7e3f53568..c8626e58f 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -94,7 +94,6 @@ class LRUCache(CacheInterface): raise ValueError("key cannot be empty") cache_path = self.path / f"{key}{_CACHE_SUFFIX}" - atime_path = self.path / f"{key}{_ATIME_SUFFIX}" if self.eviction_enabled: self.lock.acquire(timeout=self.lock_timeout_secs) @@ -108,8 +107,10 @@ class LRUCache(CacheInterface): val = cache_path.read_bytes() - timestamp = time.time_ns().to_bytes(8, "little") - atime_path.write_bytes(timestamp) + if self.eviction_enabled: + timestamp = time.time_ns().to_bytes(8, "little") + atime_path = self.path / f"{key}{_ATIME_SUFFIX}" + atime_path.write_bytes(timestamp) return val @@ -138,7 +139,6 @@ class LRUCache(CacheInterface): return cache_path = self.path / f"{key}{_CACHE_SUFFIX}" - atime_path = self.path / f"{key}{_ATIME_SUFFIX}" if self.eviction_enabled: self.lock.acquire(timeout=self.lock_timeout_secs) @@ -151,8 +151,10 @@ class LRUCache(CacheInterface): cache_path.write_bytes(val) - timestamp = time.time_ns().to_bytes(8, "little") - atime_path.write_bytes(timestamp) + if self.eviction_enabled: + timestamp = time.time_ns().to_bytes(8, "little") + atime_path = self.path / f"{key}{_ATIME_SUFFIX}" + atime_path.write_bytes(timestamp) finally: if self.eviction_enabled: diff --git a/tests/lru_cache_test.py b/tests/lru_cache_test.py index 588eccddb..438142c2b 100644 --- a/tests/lru_cache_test.py +++ b/tests/lru_cache_test.py @@ -21,7 +21,7 @@ import time from absl.testing import absltest from jax._src import path as pathlib -from jax._src.lru_cache import _CACHE_SUFFIX, LRUCache +from jax._src.lru_cache import _ATIME_SUFFIX, _CACHE_SUFFIX, LRUCache import jax._src.test_util as jtu @@ -153,6 +153,18 @@ class LRUCacheTest(LRUCacheTestCase): self.assertIsNone(cache.get("a")) self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), set()) + # Check that we don't write access time file when the eviction policy is + # disabled. Writing this file can be extremely unperformant and cause + # problems on large-scale network storage. + def test_no_atime_file(self): + cache = LRUCache(self.name, max_size=-1) + + cache.put("a", b"a") + self.assertEmpty(list(self.path.glob(f"*{_ATIME_SUFFIX}"))) + + cache.get("a") + self.assertEmpty(list(self.path.glob(f"*{_ATIME_SUFFIX}"))) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())