mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
83 lines
3.2 KiB
Python
83 lines
3.2 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
from jax.experimental.compilation_cache.cache_interface import CacheInterface
|
|
import tempfile
|
|
from typing import Optional
|
|
import warnings
|
|
|
|
class FileSystemCache(CacheInterface):
  """Cache that stores each entry as one file, named by its key, in a directory.

  The combined size of the regular files in the directory is kept at or below
  `max_cache_size_bytes` by deleting the entries with the oldest access time
  (`st_atime`) before a new entry is written.
  """

  def __init__(self, path: str, max_cache_size_bytes: int = 32 * 2**30):
    """Sets up a cache at 'path'. Cached values may already be present.

    Args:
      path: directory holding the cache entries; created if it is missing.
      max_cache_size_bytes: upper bound on the combined size of all entries.
        Defaults to 32 GiB.
    """
    os.makedirs(path, exist_ok=True)
    self._path = path
    self._max_cache_size_bytes = max_cache_size_bytes

  def get(self, key: str) -> Optional[bytes]:
    """Returns the bytes stored under 'key', or None if 'key' isn't present.

    Raises:
      ValueError: if 'key' is empty.
    """
    if not key:
      raise ValueError("key cannot be empty")
    path_to_key = os.path.join(self._path, key)
    # Open directly instead of checking os.path.exists first: another thread
    # or process may evict the file between the check and the open.
    try:
      with open(path_to_key, "rb") as file:
        return file.read()
    except FileNotFoundError:
      return None

  def put(self, key: str, value: bytes):
    """Adds new cache entry, possibly evicting older entries.

    If 'value' on its own is larger than the max cache size it is not stored;
    a warning is issued instead.

    Raises:
      ValueError: if 'key' is empty.
    """
    if not key:
      raise ValueError("key cannot be empty")
    if not self._evict_entries_if_necessary(key, value):
      warnings.warn(f"Cache value of size {len(value)} is larger than"
                    f" the max cache size of {self._max_cache_size_bytes}")
      return

    path_to_new_file = os.path.join(self._path, key)
    # Write the value to a temporary file and then move it into place, so
    # concurrent readers in other threads or processes never observe a
    # partially written entry.
    #
    # The staging directory lives inside the cache directory. That keeps the
    # move on a single filesystem, where it is an atomic rename on POSIX. A
    # rename from the system temp dir fails with EXDEV when that dir is on
    # another filesystem.
    #
    # os.replace (unlike os.rename) also overwrites an existing entry on
    # Windows. The staging directory is not a regular file, so size
    # accounting and eviction skip it.
    with tempfile.TemporaryDirectory(dir=self._path) as tmpdir:
      temp_path_to_file = os.path.join(tmpdir, key)
      with open(temp_path_to_file, "wb") as file:
        file.write(value)
        file.flush()
        os.fsync(file.fileno())
      # Move only after the file is closed; renaming an open file fails on
      # Windows.
      os.replace(temp_path_to_file, path_to_new_file)

  def _evict_entries_if_necessary(self, key: str, value: bytes) -> bool:
    """Deletes the oldest-accessed entries until 'value' fits under 'key'.

    Args:
      key: key 'value' is about to be stored under. An existing entry with
        this key is about to be replaced, so it is neither counted toward the
        current size nor evicted.
      value: bytes about to be stored.

    Returns:
      True if there's enough space to add 'value' (after any evictions);
      False if 'value' alone is larger than the max cache size, in which case
      nothing is deleted.
    """
    new_file_size = len(value)
    if new_file_size > self._max_cache_size_bytes:
      return False

    # One pass over the directory: (atime, size, path) of every other
    # regular file. Non-files (e.g. in-flight staging directories from
    # put()) are skipped; they are not counted and cannot be os.remove'd.
    entries = []
    total_size = 0
    with os.scandir(self._path) as it:
      for entry in it:
        if entry.name == key:
          continue
        try:
          if not entry.is_file():
            continue
          st = entry.stat()
        except FileNotFoundError:
          continue  # Removed concurrently by another thread or process.
        entries.append((st.st_atime, st.st_size, entry.path))
        total_size += st.st_size

    # Oldest access first. NOTE(review): st_atime may be updated lazily on
    # filesystems mounted with noatime/relatime, so this order only
    # approximates least-recently-used.
    entries.sort(key=lambda e: e[0])
    for _, size, path in entries:
      if new_file_size + total_size <= self._max_cache_size_bytes:
        break
      try:
        os.remove(path)
      except FileNotFoundError:
        pass  # Already removed by a concurrent evictor; the space is free.
      total_size -= size
    return True

  def _get_cache_directory_size(self) -> int:
    """Returns the combined size in bytes of the regular files in self._path."""
    total = 0
    with os.scandir(self._path) as it:
      for entry in it:
        try:
          if entry.is_file():
            total += entry.stat().st_size
        except FileNotFoundError:
          pass  # Removed concurrently; it no longer occupies space.
    return total
|