Use etils (gfile) to support GCS buckets and the local file system for persistent compilation caching

This commit is contained in:
Shiva Shahrokhi 2022-05-19 21:06:25 +00:00
parent 859883cfae
commit 498ee6007d
6 changed files with 149 additions and 237 deletions

View File

@ -10,6 +10,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.14 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` no longer supports `max_cache_size_bytes` and will not accept it as an input.
* Changes
* {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument
that allows selection between an LU-decomposition based implementation and
@ -45,6 +47,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
traces as an alternative to the Tensorboard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
bucket path as input.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -19,21 +19,19 @@ import sys
from typing import List, Optional
import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
from jax.experimental.compilation_cache.gfile_cache import GFileCache
import jax._src.lib
from jax._src.lib import xla_client
from absl import logging
_cache = None
def initialize_cache(path):
  """Creates a global cache object.

  Should only be called once per process. 'path' may be a local directory or
  a gs:// bucket URL; it is passed straight through to GFileCache.

  Args:
    path: the cache location (local path or gs:// URL).
  """
  global _cache
  # Identity check ('is None') is the correct idiom here; '== None' would
  # invoke __eq__ on an existing cache object.
  assert _cache is None, f"The cache path has already been initialized to {_cache._path}"
  _cache = GFileCache(path)
  logging.warning("Initialized persistent compilation cache at %s", path)
def get_executable(xla_computation, compile_options, backend) -> Optional[xla_client.Executable]:

View File

@ -1,82 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from jax.experimental.compilation_cache.cache_interface import CacheInterface
import tempfile
from typing import Optional
import warnings
class FileSystemCache(CacheInterface):
  """Persistent on-disk cache with LRU-style eviction by file access time."""

  def __init__(self, path: str, max_cache_size_bytes=32 * 2**30):
    """Sets up a cache at 'path'. Cached values may already be present.

    Args:
      path: directory to store cache entries in; created if missing.
      max_cache_size_bytes: eviction threshold, defaults to 32 GiB.
    """
    os.makedirs(path, exist_ok=True)
    self._path = path
    self._max_cache_size_bytes = max_cache_size_bytes

  def get(self, key: str) -> Optional[bytes]:
    """Returns the cached bytes for 'key', or None if 'key' isn't present.

    Raises:
      ValueError: if 'key' is empty.
    """
    if not key:
      raise ValueError("key cannot be empty")
    path_to_key = os.path.join(self._path, key)
    if os.path.exists(path_to_key):
      with open(path_to_key, "rb") as file:
        return file.read()
    else:
      return None

  def put(self, key: str, value: bytes):
    """Adds a new cache entry, possibly evicting older entries.

    If 'value' alone exceeds the maximum cache size, nothing is stored and a
    warning is issued instead.

    Raises:
      ValueError: if 'key' is empty.
    """
    if not key:
      raise ValueError("key cannot be empty")
    if self._evict_entries_if_necessary(key, value):
      path_to_new_file = os.path.join(self._path, key)
      # Write to a temporary file in the *same* directory as the cache so the
      # final move stays on one filesystem: an atomic replace only works within
      # a single filesystem, and the default system temp dir may live on a
      # different device (os.rename would then raise a cross-device OSError).
      with tempfile.TemporaryDirectory(dir=self._path) as tmpdir:
        temp_path_to_file = os.path.join(tmpdir, key)
        with open(temp_path_to_file, "wb") as file:
          file.write(value)
          file.flush()
          os.fsync(file.fileno())
        # os.replace (unlike os.rename) also overwrites atomically on Windows.
        os.replace(temp_path_to_file, path_to_new_file)
    else:
      warnings.warn(f"Cache value of size {len(value)} is larger than"
                    f" the max cache size of {self._max_cache_size_bytes}")

  def _evict_entries_if_necessary(self, key: str, value: bytes) -> bool:
    """Returns True if there's enough space to add 'value', False otherwise.

    Evicts the least-recently-accessed files until 'value' fits. A value that
    is itself at or above the cache limit is rejected outright.
    """
    new_file_size = len(value)
    if new_file_size >= self._max_cache_size_bytes:
      return False
    while new_file_size + self._get_cache_directory_size() > self._max_cache_size_bytes:
      last_time = float('inf')
      file_to_delete = None
      for file_name in os.listdir(self._path):
        file_to_inspect = os.path.join(self._path, file_name)
        # Skip non-files (e.g. an in-flight temporary directory from put()),
        # which os.remove could not delete anyway.
        if not os.path.isfile(file_to_inspect):
          continue
        atime = os.stat(file_to_inspect).st_atime
        if atime < last_time:
          last_time = atime
          file_to_delete = file_to_inspect
      assert file_to_delete
      os.remove(file_to_delete)
    return True

  def _get_cache_directory_size(self):
    """Returns the total size in bytes of regular files directly in self._path."""
    return sum(os.path.getsize(f) for f in os.scandir(self._path) if f.is_file())

View File

@ -0,0 +1,57 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
from jax.experimental.compilation_cache.cache_interface import CacheInterface
from etils import epath
from absl import logging
class GFileCache(CacheInterface):
  """Persistent cache backed by etils epath.

  Supports local paths, file:// URLs, and remote stores such as gs:// buckets.
  """

  def __init__(self, path: str):
    """Sets up a cache at 'path'. Cached values may already be present."""
    self._path = epath.Path(path)
    self._path.mkdir(parents=True, exist_ok=True)

  def get(self, key: str):
    """Returns the cached bytes for 'key', or None if 'key' isn't present.

    Raises:
      ValueError: if 'key' is empty.
    """
    if not key:
      raise ValueError("key cannot be empty")
    path_to_key = self._path / key
    if path_to_key.exists():
      return path_to_key.read_bytes()
    else:
      return None

  def put(self, key: str, value: bytes):
    """Adds a new cache entry.

    Raises:
      ValueError: if 'key' is empty.
    """
    if not key:
      raise ValueError("key cannot be empty")
    path_to_new_file = self._path / key
    if str(path_to_new_file).startswith('gs://'):
      # Writes to GCS are atomic, so no temp-file dance is needed.
      path_to_new_file.write_bytes(value)
    elif str(path_to_new_file).startswith('file://') or '://' not in str(path_to_new_file):
      # Local filesystem: write to a uniquely-named temp file in the same
      # directory, then atomically move it into place. The random suffix
      # prevents concurrent writers of the same key from clobbering each
      # other's in-flight temp file (a fixed '_temp_{key}' name would race).
      tmp_path = self._path / f"_temp_{key}_{os.urandom(8).hex()}"
      with open(str(tmp_path), "wb") as f:
        f.write(value)
        f.flush()
        os.fsync(f.fileno())
      # os.replace (unlike os.rename) also overwrites atomically on Windows.
      os.replace(str(tmp_path), str(path_to_new_file))
    else:
      # Other remote filesystems: best effort via temp file + rename;
      # atomicity depends on the backing store.
      tmp_path = self._path / f"_temp_{key}_{os.urandom(8).hex()}"
      tmp_path.write_bytes(value)
      tmp_path.rename(str(path_to_new_file))

View File

@ -1,150 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
import jax._src.test_util as jtu
import tempfile
import threading
import time
class FileSystemCacheTest(jtu.JaxTestCase):
  """Tests for FileSystemCache: basic get/put, eviction, and concurrency."""

  def test_get_nonexistent_key(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      self.assertEqual(cache.get("nonExistentKey"), None)

  def test_put_and_get_key(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      cache.put("foo", b"bar")
      self.assertEqual(cache.get("foo"), b"bar")

  def test_existing_cache_path(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache1 = FileSystemCache(tmpdir)
      cache1.put("foo", b"bar")
      del cache1
      cache2 = FileSystemCache(tmpdir)
      self.assertEqual(cache2.get("foo"), b"bar")

  def test_empty_value_put(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      cache.put("foo", b"")
      self.assertEqual(cache.get("foo"), b"")

  def test_empty_key_put(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
        cache.put("", b"bar")

  def test_empty_key_get(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
        cache.get("")

  def test_size_of_directory(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      cache.put("foo", b"bar")
      self.assertEqual(cache._get_cache_directory_size(), 3)

  def test_size_of_empty_directory(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      self.assertEqual(cache._get_cache_directory_size(), 0)

  def test_size_of_existing_directory(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache1 = FileSystemCache(tmpdir)
      cache1.put("foo", b"bar")
      del cache1
      cache2 = FileSystemCache(tmpdir)
      self.assertEqual(cache2._get_cache_directory_size(), 3)

  def test_cache_is_full(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir, max_cache_size_bytes=6)
      cache.put("first", b"one")
      # Sleep because otherwise these operations execute too fast and
      # the access time isn't captured properly.
      time.sleep(1)
      cache.put("second", b"two")
      cache.put("third", b"the")
      self.assertEqual(cache.get("first"), None)
      self.assertEqual(cache.get("second"), b"two")
      self.assertEqual(cache.get("third"), b"the")

  def test_delete_multiple_files(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir, max_cache_size_bytes=6)
      cache.put("first", b"one")
      # Sleep because otherwise these operations execute too fast and
      # the access time isn't captured properly.
      time.sleep(1)
      cache.put("second", b"two")
      cache.put("third", b"three")
      self.assertEqual(cache.get("first"), None)
      self.assertEqual(cache.get("second"), None)
      self.assertEqual(cache.get("third"), b"three")

  def test_least_recently_accessed_file(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir, max_cache_size_bytes=6)
      cache.put("first", b"one")
      cache.put("second", b"two")
      # Sleep because otherwise these operations execute too fast and
      # the access time isn't captured properly.
      time.sleep(1)
      cache.get("first")
      cache.put("third", b"the")
      self.assertEqual(cache.get("first"), b"one")
      self.assertEqual(cache.get("second"), None)

  @jtu.ignore_warning(message=("Cache value of size 3 is larger than the max cache size of 2"))
  def test_file_bigger_than_cache(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir, max_cache_size_bytes=2)
      cache.put("foo", b"bar")
      self.assertEqual(cache.get("foo"), None)

  def test_threads(self):
    file_contents1 = "1" * (65536 + 1)
    file_contents2 = "2" * (65536 + 1)

    def call_multiple_puts_and_gets(cache):
      for _ in range(50):
        cache.put("foo", file_contents1.encode('utf-8').strip())
        cache.put("foo", file_contents2.encode('utf-8').strip())
        cache.get("foo")
      self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip())

    with tempfile.TemporaryDirectory() as tmpdir:
      cache = FileSystemCache(tmpdir)
      threads = []
      for _ in range(50):
        # Pass the callable and its args separately. The original
        # `target=call_multiple_puts_and_gets(cache)` invoked the function
        # immediately on the main thread and gave Thread a None target, so
        # the threads never did any work.
        t = threading.Thread(target=call_multiple_puts_and_gets, args=(cache,))
        t.start()
        threads.append(t)
      for t in threads:
        t.join()
      self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip())


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())

85
tests/gfile_cache_test.py Normal file
View File

@ -0,0 +1,85 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax.experimental.compilation_cache.gfile_cache import GFileCache
import jax._src.test_util as jtu
import tempfile
import threading
class GFileCacheTest(jtu.JaxTestCase):
  """Tests for GFileCache: basic get/put on a local directory and concurrency.

  Renamed from the copy-pasted 'FileSystemCacheTest' — this file tests
  GFileCache, and the duplicate class name could shadow the other test suite.
  """

  def test_get_nonexistent_key(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = GFileCache(tmpdir)
      self.assertEqual(cache.get("nonExistentKey"), None)

  def test_put_and_get_key(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = GFileCache(tmpdir)
      cache.put("foo", b"bar")
      self.assertEqual(cache.get("foo"), b"bar")

  def test_existing_cache_path(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache1 = GFileCache(tmpdir)
      cache1.put("foo", b"bar")
      del cache1
      cache2 = GFileCache(tmpdir)
      self.assertEqual(cache2.get("foo"), b"bar")

  def test_empty_value_put(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = GFileCache(tmpdir)
      cache.put("foo", b"")
      self.assertEqual(cache.get("foo"), b"")

  def test_empty_key_put(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = GFileCache(tmpdir)
      with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
        cache.put("", b"bar")

  def test_empty_key_get(self):
    with tempfile.TemporaryDirectory() as tmpdir:
      cache = GFileCache(tmpdir)
      with self.assertRaisesRegex(ValueError, r"key cannot be empty"):
        cache.get("")

  def test_threads(self):
    file_contents1 = "1" * (65536 + 1)
    file_contents2 = "2" * (65536 + 1)

    def call_multiple_puts_and_gets(cache):
      for _ in range(50):
        cache.put("foo", file_contents1.encode('utf-8').strip())
        cache.put("foo", file_contents2.encode('utf-8').strip())
        cache.get("foo")
      self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip())

    with tempfile.TemporaryDirectory() as tmpdir:
      cache = GFileCache(tmpdir)
      threads = []
      for _ in range(50):
        # Pass the callable and its args separately. The original
        # `target=call_multiple_puts_and_gets(cache)` invoked the function
        # immediately on the main thread and gave Thread a None target, so
        # the threads never did any work.
        t = threading.Thread(target=call_multiple_puts_and_gets, args=(cache,))
        t.start()
        threads.append(t)
      for t in threads:
        t.join()
      self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip())


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())