Make persistent compilation cache warn instead of raise an error on cache read/write failures

Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.

Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
This commit is contained in:
Skye Wanderman-Milne 2022-09-27 20:59:08 +00:00
parent 9ff570e6c3
commit 15e5f38a16
4 changed files with 98 additions and 8 deletions

View File

@ -9,6 +9,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
--> -->
## jax 0.3.21 ## jax 0.3.21
* Changes
* The persistent compilation cache will now warn instead of raising an
exception on error ({jax-issue}`#12582`), so program execution can continue
if something goes wrong with the cache. Set
`JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior.
## jaxlib 0.3.21 ## jaxlib 0.3.21

View File

@ -699,6 +699,16 @@ enable_custom_vjp_by_custom_transpose = config.define_bool_state(
help=('Enables an internal upgrade that implements `jax.custom_vjp` by ' help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.')) 'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
raise_persistent_cache_errors = config.define_bool_state(
name='jax_raise_persistent_cache_errors',
default=False,
help=('If true, exceptions raised when reading or writing to the '
'persistent compilation cache will be allowed through, halting '
'program execution if not manually caught. If false, exceptions are '
'caught and raised as warnings, allowing program execution to '
'continue. Defaults to false so cache bugs or intermittent issues '
'are non-fatal.'))
hlo_source_file_canonicalization_regex = config.define_string_state( hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex', name='jax_hlo_source_file_canonicalization_regex',
default=None, default=None,

View File

@ -78,6 +78,7 @@ Device = xc.Device
Buffer = xe.Buffer Buffer = xe.Buffer
XlaExecutable = xc.Executable XlaExecutable = xc.Executable
CompileOptions = xc.CompileOptions
map, unsafe_map = util.safe_map, map map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip zip, unsafe_zip = util.safe_zip, zip
@ -1016,6 +1017,10 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
sym_name = computation.operation.attributes['sym_name'] sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value module_name = ir.StringAttr(sym_name).value
if FLAGS.jax_dump_ir_to:
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
# Convert ir.Module to a string representation, unless the # Convert ir.Module to a string representation, unless the
# back-end expliclity flags the ability to handle a module directly # back-end expliclity flags the ability to handle a module directly
# (avoiding the overhead of back and forth conversions) # (avoiding the overhead of back and forth conversions)
@ -1036,23 +1041,57 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""): if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""):
supported_platforms.append("gpu") supported_platforms.append("gpu")
if cc.is_initialized() and backend.platform in supported_platforms: if cc.is_initialized() and backend.platform in supported_platforms:
cached_executable = cc.get_executable(serialized_computation, cached_executable = _cache_read(serialized_computation, module_name,
compile_options, backend) compile_options, backend)
if cached_executable is not None: if cached_executable is not None:
logging.info('Persistent compilation cache hit for %s.', module_name) logging.info("Persistent compilation cache hit for '%s'", module_name)
return cached_executable return cached_executable
else: else:
compiled = backend_compile(backend, serialized_computation, compiled = backend_compile(backend, serialized_computation,
compile_options, host_callbacks) compile_options, host_callbacks)
cc.put_executable(module_name, serialized_computation, compile_options, _cache_write(serialized_computation, module_name, compile_options,
compiled, backend) backend, compiled)
return compiled return compiled
if FLAGS.jax_dump_ir_to:
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
return backend_compile(backend, serialized_computation, compile_options, return backend_compile(backend, serialized_computation, compile_options,
host_callbacks) host_callbacks)
def _cache_read(computation: Union[str, bytes, ir.Module],
module_name: str,
compile_options: CompileOptions,
backend: Backend) -> Optional[XlaExecutable]:
"""Looks up `computation` in the persisent compilation cache."""
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
try:
return cc.get_executable(computation, compile_options, backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error reading persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
return None
def _cache_write(computation: Union[str, bytes, ir.Module],
module_name: str,
compile_options: CompileOptions,
backend: Backend,
compiled: XlaExecutable):
"""Writes `computation` to the persistent compilation cache."""
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
try:
cc.put_executable(module_name, computation, compile_options, compiled,
backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error writing persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects): def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals] buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]

View File

@ -19,7 +19,8 @@ import random
import sys import sys
import tempfile import tempfile
import unittest import unittest
from unittest import SkipTest from unittest import mock, SkipTest
import warnings
from absl.testing import absltest from absl.testing import absltest
from jax.experimental import PartitionSpec as P from jax.experimental import PartitionSpec as P
@ -35,9 +36,12 @@ from jax._src.lib import xla_client
import numpy as np import numpy as np
from jax.config import config from jax.config import config
from jax._src.config import raise_persistent_cache_errors
config.parse_flags_with_absl() config.parse_flags_with_absl()
FLAGS = config.FLAGS FLAGS = config.FLAGS
@jtu.with_config(jax_raise_persistent_cache_errors=True)
class CompilationCacheTest(jtu.JaxTestCase): class CompilationCacheTest(jtu.JaxTestCase):
def setUp(self): def setUp(self):
@ -295,6 +299,38 @@ class CompilationCacheTest(jtu.JaxTestCase):
files_in_directory = len(os.listdir(tmpdir)) files_in_directory = len(os.listdir(tmpdir))
self.assertEqual(files_in_directory, 2) self.assertEqual(files_in_directory, 2)
def test_cache_write_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
f = jit(lambda x: x*x)
with raise_persistent_cache_errors(False), \
mock.patch.object(cc._cache.__class__, 'put') as mock_put, \
warnings.catch_warnings(record=True) as w:
mock_put.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
self.assertLen(w, 1)
self.assertIn(
"Error writing persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error",
str(w[0].message))
def test_cache_read_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
f = jit(lambda x: x*x)
with raise_persistent_cache_errors(False), \
mock.patch.object(cc._cache.__class__, 'get') as mock_get, \
warnings.catch_warnings(record=True) as w:
mock_get.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
self.assertLen(w, 1)
self.assertIn(
"Error reading persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error",
str(w[0].message))
def create_new_debug_options(self, debug_options_obj): def create_new_debug_options(self, debug_options_obj):
debug_options_obj.xla_cpu_enable_fast_math = False debug_options_obj.xla_cpu_enable_fast_math = False
debug_options_obj.xla_cpu_fast_math_honor_infs = False debug_options_obj.xla_cpu_fast_math_honor_infs = False