Make persistent compilation cache warn instead of raise an error on cache read/write failures

Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.

Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
This commit is contained in:
Skye Wanderman-Milne 2022-09-27 20:59:08 +00:00
parent 9ff570e6c3
commit 15e5f38a16
4 changed files with 98 additions and 8 deletions

View File

@ -9,6 +9,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->
## jax 0.3.21
* Changes
* The persistent compilation cache will now warn instead of raising an
exception on error ({jax-issue}`#12582`), so program execution can continue
if something goes wrong with the cache. Set
`JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior.
## jaxlib 0.3.21

View File

@ -699,6 +699,16 @@ enable_custom_vjp_by_custom_transpose = config.define_bool_state(
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
raise_persistent_cache_errors = config.define_bool_state(
name='jax_raise_persistent_cache_errors',
default=False,
help=('If true, exceptions raised when reading or writing to the '
'persistent compilation cache will be allowed through, halting '
'program execution if not manually caught. If false, exceptions are '
'caught and raised as warnings, allowing program execution to '
'continue. Defaults to false so cache bugs or intermittent issues '
'are non-fatal.'))
hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,

View File

@ -78,6 +78,7 @@ Device = xc.Device
Buffer = xe.Buffer
XlaExecutable = xc.Executable
CompileOptions = xc.CompileOptions
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
@ -1016,6 +1017,10 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value
if FLAGS.jax_dump_ir_to:
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
# Convert ir.Module to a string representation, unless the
# back-end expliclity flags the ability to handle a module directly
# (avoiding the overhead of back and forth conversions)
@ -1036,23 +1041,57 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""):
supported_platforms.append("gpu")
if cc.is_initialized() and backend.platform in supported_platforms:
cached_executable = cc.get_executable(serialized_computation,
compile_options, backend)
cached_executable = _cache_read(serialized_computation, module_name,
compile_options, backend)
if cached_executable is not None:
logging.info('Persistent compilation cache hit for %s.', module_name)
logging.info("Persistent compilation cache hit for '%s'", module_name)
return cached_executable
else:
compiled = backend_compile(backend, serialized_computation,
compile_options, host_callbacks)
cc.put_executable(module_name, serialized_computation, compile_options,
compiled, backend)
_cache_write(serialized_computation, module_name, compile_options,
backend, compiled)
return compiled
if FLAGS.jax_dump_ir_to:
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
return backend_compile(backend, serialized_computation, compile_options,
host_callbacks)
def _cache_read(computation: Union[str, bytes, ir.Module],
module_name: str,
compile_options: CompileOptions,
backend: Backend) -> Optional[XlaExecutable]:
"""Looks up `computation` in the persisent compilation cache."""
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
try:
return cc.get_executable(computation, compile_options, backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error reading persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
return None
def _cache_write(computation: Union[str, bytes, ir.Module],
module_name: str,
compile_options: CompileOptions,
backend: Backend,
compiled: XlaExecutable):
"""Writes `computation` to the persistent compilation cache."""
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
try:
cc.put_executable(module_name, computation, compile_options, compiled,
backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error writing persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]

View File

@ -19,7 +19,8 @@ import random
import sys
import tempfile
import unittest
from unittest import SkipTest
from unittest import mock, SkipTest
import warnings
from absl.testing import absltest
from jax.experimental import PartitionSpec as P
@ -35,9 +36,12 @@ from jax._src.lib import xla_client
import numpy as np
from jax.config import config
from jax._src.config import raise_persistent_cache_errors
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@jtu.with_config(jax_raise_persistent_cache_errors=True)
class CompilationCacheTest(jtu.JaxTestCase):
def setUp(self):
@ -295,6 +299,38 @@ class CompilationCacheTest(jtu.JaxTestCase):
files_in_directory = len(os.listdir(tmpdir))
self.assertEqual(files_in_directory, 2)
def test_cache_write_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
f = jit(lambda x: x*x)
with raise_persistent_cache_errors(False), \
mock.patch.object(cc._cache.__class__, 'put') as mock_put, \
warnings.catch_warnings(record=True) as w:
mock_put.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
self.assertLen(w, 1)
self.assertIn(
"Error writing persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error",
str(w[0].message))
def test_cache_read_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
f = jit(lambda x: x*x)
with raise_persistent_cache_errors(False), \
mock.patch.object(cc._cache.__class__, 'get') as mock_get, \
warnings.catch_warnings(record=True) as w:
mock_get.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
self.assertLen(w, 1)
self.assertIn(
"Error reading persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error",
str(w[0].message))
def create_new_debug_options(self, debug_options_obj):
debug_options_obj.xla_cpu_enable_fast_math = False
debug_options_obj.xla_cpu_fast_math_honor_infs = False