diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d5fa9428..5ac7e96e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> ## jax 0.3.21 +* Changes + * The persistent compilation cache will now warn instead of raising an + exception on error ({jax-issue}`#12582`), so program execution can continue + if something goes wrong with the cache. Set + `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior. ## jaxlib 0.3.21 diff --git a/jax/_src/config.py b/jax/_src/config.py index e4421f925..e4c6411e1 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -699,6 +699,16 @@ enable_custom_vjp_by_custom_transpose = config.define_bool_state( help=('Enables an internal upgrade that implements `jax.custom_vjp` by ' 'reduction to `jax.custom_jvp` and `jax.custom_transpose`.')) +raise_persistent_cache_errors = config.define_bool_state( + name='jax_raise_persistent_cache_errors', + default=False, + help=('If true, exceptions raised when reading or writing to the ' + 'persistent compilation cache will be allowed through, halting ' + 'program execution if not manually caught. If false, exceptions are ' + 'caught and raised as warnings, allowing program execution to ' + 'continue. Defaults to false so cache bugs or intermittent issues ' + 'are non-fatal.')) + hlo_source_file_canonicalization_regex = config.define_string_state( name='jax_hlo_source_file_canonicalization_regex', default=None, diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 60d11034b..531049c33 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -78,6 +78,7 @@ Device = xc.Device Buffer = xe.Buffer XlaExecutable = xc.Executable +CompileOptions = xc.CompileOptions map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -1016,6 +1017,10 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options, sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value + + if FLAGS.jax_dump_ir_to: + _dump_ir_to_file(module_name, mlir.module_to_string(computation)) + # Convert ir.Module to a string representation, unless the # back-end expliclity flags the ability to handle a module directly # (avoiding the overhead of back and forth conversions) @@ -1036,23 +1041,57 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options, if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""): supported_platforms.append("gpu") if cc.is_initialized() and backend.platform in supported_platforms: - cached_executable = cc.get_executable(serialized_computation, - compile_options, backend) + cached_executable = _cache_read(serialized_computation, module_name, + compile_options, backend) if cached_executable is not None: - logging.info('Persistent compilation cache hit for %s.', module_name) + logging.info("Persistent compilation cache hit for '%s'", module_name) return cached_executable else: compiled = backend_compile(backend, serialized_computation, compile_options, host_callbacks) - cc.put_executable(module_name, serialized_computation, compile_options, - compiled, backend) + _cache_write(serialized_computation, module_name, compile_options, + backend, compiled) return compiled - if FLAGS.jax_dump_ir_to: - _dump_ir_to_file(module_name, mlir.module_to_string(computation)) return backend_compile(backend, serialized_computation, compile_options, host_callbacks) +def _cache_read(computation: Union[str, bytes, ir.Module], + module_name: str, + compile_options: CompileOptions, + backend: Backend) -> Optional[XlaExecutable]: + """Looks up `computation` in the persisent compilation cache.""" + # Avoid import cycle between jax and jax.experimental + from jax.experimental.compilation_cache import compilation_cache as cc + + try: + return cc.get_executable(computation, compile_options, backend) + except Exception as ex: + if config.jax_raise_persistent_cache_errors: + raise + warnings.warn( + f"Error reading persistent compilation cache entry for " + f"'{module_name}': {type(ex).__name__}: {ex}") + return None + +def _cache_write(computation: Union[str, bytes, ir.Module], + module_name: str, + compile_options: CompileOptions, + backend: Backend, + compiled: XlaExecutable): + """Writes `computation` to the persistent compilation cache.""" + # Avoid import cycle between jax and jax.experimental + from jax.experimental.compilation_cache import compilation_cache as cc + + try: + cc.put_executable(module_name, computation, compile_options, compiled, + backend) + except Exception as ex: + if config.jax_raise_persistent_cache_errors: + raise + warnings.warn( + f"Error writing persistent compilation cache entry for " + f"'{module_name}': {type(ex).__name__}: {ex}") def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects): buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals] diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index bf959dd54..c603d1cdf 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -19,7 +19,8 @@ import random import sys import tempfile import unittest -from unittest import SkipTest +from unittest import mock, SkipTest +import warnings from absl.testing import absltest from jax.experimental import PartitionSpec as P @@ -35,9 +36,12 @@ from jax._src.lib import xla_client import numpy as np from jax.config import config +from jax._src.config import raise_persistent_cache_errors + config.parse_flags_with_absl() FLAGS = config.FLAGS +@jtu.with_config(jax_raise_persistent_cache_errors=True) class CompilationCacheTest(jtu.JaxTestCase): def setUp(self): @@ -295,6 +299,38 @@ class CompilationCacheTest(jtu.JaxTestCase): files_in_directory = len(os.listdir(tmpdir)) self.assertEqual(files_in_directory, 2) + def test_cache_write_warning(self): + with tempfile.TemporaryDirectory() as tmpdir: + cc.initialize_cache(tmpdir) + f = jit(lambda x: x*x) + + with raise_persistent_cache_errors(False), \ + mock.patch.object(cc._cache.__class__, 'put') as mock_put, \ + warnings.catch_warnings(record=True) as w: + mock_put.side_effect = RuntimeError("test error") + self.assertEqual(f(2), 4) + self.assertLen(w, 1) + self.assertIn( + "Error writing persistent compilation cache entry " + "for 'jit__lambda_': RuntimeError: test error", + str(w[0].message)) + + def test_cache_read_warning(self): + with tempfile.TemporaryDirectory() as tmpdir: + cc.initialize_cache(tmpdir) + f = jit(lambda x: x*x) + + with raise_persistent_cache_errors(False), \ + mock.patch.object(cc._cache.__class__, 'get') as mock_get, \ + warnings.catch_warnings(record=True) as w: + mock_get.side_effect = RuntimeError("test error") + self.assertEqual(f(2), 4) + self.assertLen(w, 1) + self.assertIn( + "Error reading persistent compilation cache entry " + "for 'jit__lambda_': RuntimeError: test error", + str(w[0].message)) + def create_new_debug_options(self, debug_options_obj): debug_options_obj.xla_cpu_enable_fast_math = False debug_options_obj.xla_cpu_fast_math_honor_infs = False