diff --git a/CHANGELOG.md b/CHANGELOG.md index cc17898fb..414e19a0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.3.14 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...main). +* Breaking changes + * {func}`jax.experimental.compilation_cache.initialize_cache` does not support `max_cache_size_ bytes` anymore and will not get that as an input. * Changes * {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument that allows selection between an LU-decomposition based implementation and @@ -45,6 +47,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. traces as an alternative to the Tensorboard UI. * Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`). + * {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs + bucket path as input. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 3077de6bd..1e64313ec 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -19,21 +19,19 @@ import sys from typing import List, Optional import jax -from jax.experimental.compilation_cache.file_system_cache import FileSystemCache +from jax.experimental.compilation_cache.gfile_cache import GFileCache import jax._src.lib from jax._src.lib import xla_client from absl import logging _cache = None -def initialize_cache(path, max_cache_size_bytes=32 * 2**30): +def initialize_cache(path): """Creates a global cache object. Should only be called once per process. - - max_cache_sixe defaults to 32GiB. """ global _cache assert _cache == None, f"The cache path has already been initialized to {_cache._path}" - _cache = FileSystemCache(path, max_cache_size_bytes) + _cache = GFileCache(path) logging.warning("Initialized persistent compilation cache at %s", path) def get_executable(xla_computation, compile_options, backend) -> Optional[xla_client.Executable]: diff --git a/jax/experimental/compilation_cache/file_system_cache.py b/jax/experimental/compilation_cache/file_system_cache.py deleted file mode 100644 index b85969de2..000000000 --- a/jax/experimental/compilation_cache/file_system_cache.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from jax.experimental.compilation_cache.cache_interface import CacheInterface -import tempfile -from typing import Optional -import warnings - -class FileSystemCache(CacheInterface): - - def __init__(self, path: str, max_cache_size_bytes=32 * 2**30): - """Sets up a cache at 'path'. Cached values may already be present.""" - os.makedirs(path, exist_ok=True) - self._path = path - self._max_cache_size_bytes = max_cache_size_bytes - - def get(self, key: str) -> Optional[bytes]: - """Returns None if 'key' isn't present.""" - if not key: - raise ValueError("key cannot be empty") - path_to_key = os.path.join(self._path, key) - if os.path.exists(path_to_key): - with open(path_to_key, "rb") as file: - return file.read() - else: - return None - - def put(self, key: str, value: bytes): - """Adds new cache entry, possibly evicting older entries.""" - if not key: - raise ValueError("key cannot be empty") - if self._evict_entries_if_necessary(key, value): - path_to_new_file = os.path.join(self._path, key) - # Create the path for the file in a temporary directory so we can use the - # atomic move function to ensure that the file is properly stored and read - # in the case of concurrent access across multiple threads or processes - with tempfile.TemporaryDirectory() as tmpdir: - temp_path_to_file = os.path.join(tmpdir, key) - with open(temp_path_to_file, "wb") as file: - file.write(value) - file.flush() - os.fsync(file.fileno()) - os.rename(temp_path_to_file, path_to_new_file) - else: - warnings.warn(f"Cache value of size {len(value)} is larger than" - f" the max cache size of {self._max_cache_size_bytes}") - - def _evict_entries_if_necessary(self, key: str, value: bytes) -> bool: - """Returns True if there's enough space to add 'value', False otherwise.""" - new_file_size = len(value) - - if new_file_size >= self._max_cache_size_bytes: - return False - - while new_file_size + self._get_cache_directory_size() > self._max_cache_size_bytes: - last_time = float('inf') - file_to_delete = None - for file_name in os.listdir(self._path): - file_to_inspect = os.path.join(self._path, file_name) - atime = os.stat(file_to_inspect).st_atime - if atime < last_time: - last_time = atime - file_to_delete = file_to_inspect - assert file_to_delete - os.remove(file_to_delete) - return True - - def _get_cache_directory_size(self): - """Retrieves the current size of the directory, self.path""" - return sum(os.path.getsize(f) for f in os.scandir(self._path) if f.is_file()) diff --git a/jax/experimental/compilation_cache/gfile_cache.py b/jax/experimental/compilation_cache/gfile_cache.py new file mode 100644 index 000000000..9c3dc4579 --- /dev/null +++ b/jax/experimental/compilation_cache/gfile_cache.py @@ -0,0 +1,57 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib + +from jax.experimental.compilation_cache.cache_interface import CacheInterface +from etils import epath +from absl import logging + +class GFileCache(CacheInterface): + + def __init__(self, path: str): + """Sets up a cache at 'path'. Cached values may already be present.""" + self._path = epath.Path(path) + self._path.mkdir(parents=True, exist_ok=True) + + def get(self, key: str): + """Returns None if 'key' isn't present.""" + if not key: + raise ValueError("key cannot be empty") + path_to_key = self._path / key + if path_to_key.exists(): + return path_to_key.read_bytes() + else: + return None + + def put(self, key: str, value: bytes): + """Adds new cache entry.""" + if not key: + raise ValueError("key cannot be empty") + path_to_new_file = self._path / key + if str(path_to_new_file).startswith('gs://'): + # Writes to gcs are atomic. + path_to_new_file.write_bytes(value) + elif str(path_to_new_file).startswith('file://') or '://' not in str(path_to_new_file): + tmp_path = self._path / f"_temp_{key}" + with open(str(tmp_path), "wb") as f: + f.write(value) + f.flush() + os.fsync(f.fileno()) + os.rename(tmp_path, path_to_new_file) + else: + tmp_path = self._path / f"_temp_{key}" + tmp_path.write_bytes(value) + tmp_path.rename(str(path_to_new_file)) diff --git a/tests/file_system_cache_test.py b/tests/file_system_cache_test.py deleted file mode 100644 index 21d41aa4a..000000000 --- a/tests/file_system_cache_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from jax.experimental.compilation_cache.file_system_cache import FileSystemCache -import jax._src.test_util as jtu -import tempfile -import threading -import time - -class FileSystemCacheTest(jtu.JaxTestCase): - - def test_get_nonexistent_key(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - self.assertEqual(cache.get("nonExistentKey"), None) - - def test_put_and_get_key(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - cache.put("foo", b"bar") - self.assertEqual(cache.get("foo"), b"bar") - - def test_existing_cache_path(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache1 = FileSystemCache(tmpdir) - cache1.put("foo", b"bar") - del cache1 - cache2 = FileSystemCache(tmpdir) - self.assertEqual(cache2.get("foo"), b"bar") - - def test_empty_value_put(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - cache.put("foo", b"") - self.assertEqual(cache.get("foo"), b"") - - def test_empty_key_put(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - with self.assertRaisesRegex(ValueError , r"key cannot be empty"): - cache.put("", b"bar") - - def test_empty_key_get(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - with self.assertRaisesRegex(ValueError , r"key cannot be empty"): - cache.get("") - - def test_size_of_directory(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - cache.put("foo", b"bar") - self.assertEqual(cache._get_cache_directory_size(), 3) - - def test_size_of_empty_directory(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - self.assertEqual(cache._get_cache_directory_size(), 0) - - def test_size_of_existing_directory(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache1 = FileSystemCache(tmpdir) - cache1.put("foo", b"bar") - del cache1 - cache2 = FileSystemCache(tmpdir) - self.assertEqual(cache2._get_cache_directory_size(), 3) - - def test_cache_is_full(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir, max_cache_size_bytes=6) - cache.put("first", b"one") - # Sleep because otherwise these operations execute too fast and - # the access time isn't captured properly. - time.sleep(1) - cache.put("second", b"two") - cache.put("third", b"the") - self.assertEqual(cache.get("first"), None) - self.assertEqual(cache.get("second"), b"two") - self.assertEqual(cache.get("third"), b"the") - - def test_delete_multiple_files(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir, max_cache_size_bytes=6) - cache.put("first", b"one") - # Sleep because otherwise these operations execute too fast and - # the access time isn't captured properly. - time.sleep(1) - cache.put("second", b"two") - cache.put("third", b"three") - self.assertEqual(cache.get("first"), None) - self.assertEqual(cache.get("second"), None) - self.assertEqual(cache.get("third"), b"three") - - def test_least_recently_accessed_file(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir, max_cache_size_bytes=6) - cache.put("first", b"one") - cache.put("second", b"two") - # Sleep because otherwise these operations execute too fast and - # the access time isn't captured properly. - time.sleep(1) - cache.get("first") - cache.put("third", b"the") - self.assertEqual(cache.get("first"), b"one") - self.assertEqual(cache.get("second"), None) - - @jtu.ignore_warning(message=("Cache value of size 3 is larger than the max cache size of 2")) - def test_file_bigger_than_cache(self): - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir, max_cache_size_bytes=2) - cache.put("foo", b"bar") - self.assertEqual(cache.get("foo"), None) - - def test_threads(self): - file_contents1 = "1" * (65536 + 1) - file_contents2 = "2" * (65536 + 1) - - def call_multiple_puts_and_gets(cache): - for i in range(50): - cache.put("foo", file_contents1.encode('utf-8').strip()) - cache.put("foo", file_contents2.encode('utf-8').strip()) - cache.get("foo") - self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip()) - - with tempfile.TemporaryDirectory() as tmpdir: - cache = FileSystemCache(tmpdir) - threads = [] - for i in range(50): - t = threading.Thread(target=call_multiple_puts_and_gets(cache)) - t.start() - threads.append(t) - for t in threads: - t.join() - - self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip()) - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/gfile_cache_test.py b/tests/gfile_cache_test.py new file mode 100644 index 000000000..8f3bf4aac --- /dev/null +++ b/tests/gfile_cache_test.py @@ -0,0 +1,85 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from jax.experimental.compilation_cache.gfile_cache import GFileCache +import jax._src.test_util as jtu +import tempfile +import threading + +class FileSystemCacheTest(jtu.JaxTestCase): + + def test_get_nonexistent_key(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache = GFileCache(tmpdir) + self.assertEqual(cache.get("nonExistentKey"), None) + + def test_put_and_get_key(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache = GFileCache(tmpdir) + cache.put("foo", b"bar") + self.assertEqual(cache.get("foo"), b"bar") + + def test_existing_cache_path(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache1 = GFileCache(tmpdir) + cache1.put("foo", b"bar") + del cache1 + cache2 = GFileCache(tmpdir) + self.assertEqual(cache2.get("foo"), b"bar") + + def test_empty_value_put(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache = GFileCache(tmpdir) + cache.put("foo", b"") + self.assertEqual(cache.get("foo"), b"") + + def test_empty_key_put(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache = GFileCache(tmpdir) + with self.assertRaisesRegex(ValueError , r"key cannot be empty"): + cache.put("", b"bar") + + def test_empty_key_get(self): + with tempfile.TemporaryDirectory() as tmpdir: + cache = GFileCache(tmpdir) + with self.assertRaisesRegex(ValueError , r"key cannot be empty"): + cache.get("") + + + def test_threads(self): + file_contents1 = "1" * (65536 + 1) + file_contents2 = "2" * (65536 + 1) + + def call_multiple_puts_and_gets(cache): + for i in range(50): + cache.put("foo", file_contents1.encode('utf-8').strip()) + cache.put("foo", file_contents2.encode('utf-8').strip()) + cache.get("foo") + self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip()) + + with tempfile.TemporaryDirectory() as tmpdir: + cache = GFileCache(tmpdir) + threads = [] + for i in range(50): + t = threading.Thread(target=call_multiple_puts_and_gets(cache)) + t.start() + threads.append(t) + for t in threads: + t.join() + + self.assertEqual(cache.get("foo"), file_contents2.encode('utf-8').strip()) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())