mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Make persistent compilation cache warn instead of raising an error on cache read/write failures
Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exceptions instead of warning. Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
This commit is contained in:
parent
9ff570e6c3
commit
15e5f38a16
@ -9,6 +9,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
-->
|
||||
|
||||
## jax 0.3.21
|
||||
* Changes
|
||||
* The persistent compilation cache will now warn instead of raising an
|
||||
exception on error ({jax-issue}`#12582`), so program execution can continue
|
||||
if something goes wrong with the cache. Set
|
||||
`JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior.
|
||||
|
||||
## jaxlib 0.3.21
|
||||
|
||||
|
@ -699,6 +699,16 @@ enable_custom_vjp_by_custom_transpose = config.define_bool_state(
|
||||
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
|
||||
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))
|
||||
|
||||
raise_persistent_cache_errors = config.define_bool_state(
|
||||
name='jax_raise_persistent_cache_errors',
|
||||
default=False,
|
||||
help=('If true, exceptions raised when reading or writing to the '
|
||||
'persistent compilation cache will be allowed through, halting '
|
||||
'program execution if not manually caught. If false, exceptions are '
|
||||
'caught and raised as warnings, allowing program execution to '
|
||||
'continue. Defaults to false so cache bugs or intermittent issues '
|
||||
'are non-fatal.'))
|
||||
|
||||
hlo_source_file_canonicalization_regex = config.define_string_state(
|
||||
name='jax_hlo_source_file_canonicalization_regex',
|
||||
default=None,
|
||||
|
@ -78,6 +78,7 @@ Device = xc.Device
|
||||
Buffer = xe.Buffer
|
||||
|
||||
XlaExecutable = xc.Executable
|
||||
CompileOptions = xc.CompileOptions
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
@ -1016,6 +1017,10 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
|
||||
sym_name = computation.operation.attributes['sym_name']
|
||||
module_name = ir.StringAttr(sym_name).value
|
||||
|
||||
if FLAGS.jax_dump_ir_to:
|
||||
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
||||
|
||||
# Convert ir.Module to a string representation, unless the
|
||||
# back-end explicitly flags the ability to handle a module directly
|
||||
# (avoiding the overhead of back and forth conversions)
|
||||
@ -1036,23 +1041,57 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
||||
if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""):
|
||||
supported_platforms.append("gpu")
|
||||
if cc.is_initialized() and backend.platform in supported_platforms:
|
||||
cached_executable = cc.get_executable(serialized_computation,
|
||||
compile_options, backend)
|
||||
cached_executable = _cache_read(serialized_computation, module_name,
|
||||
compile_options, backend)
|
||||
if cached_executable is not None:
|
||||
logging.info('Persistent compilation cache hit for %s.', module_name)
|
||||
logging.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
return cached_executable
|
||||
else:
|
||||
compiled = backend_compile(backend, serialized_computation,
|
||||
compile_options, host_callbacks)
|
||||
cc.put_executable(module_name, serialized_computation, compile_options,
|
||||
compiled, backend)
|
||||
_cache_write(serialized_computation, module_name, compile_options,
|
||||
backend, compiled)
|
||||
return compiled
|
||||
|
||||
if FLAGS.jax_dump_ir_to:
|
||||
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
||||
return backend_compile(backend, serialized_computation, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
def _cache_read(computation: Union[str, bytes, ir.Module],
                module_name: str,
                compile_options: CompileOptions,
                backend: Backend) -> Optional[XlaExecutable]:
  """Looks up `computation` in the persistent compilation cache.

  Args:
    computation: the computation to look up, forwarded unchanged to
      `cc.get_executable`.
    module_name: name of the module; used only in the warning text.
    compile_options: forwarded unchanged to `cc.get_executable`.
    backend: forwarded unchanged to `cc.get_executable`.

  Returns:
    Whatever `cc.get_executable` returns, or None if the lookup raised and
    `config.jax_raise_persistent_cache_errors` is false.

  Raises:
    Exception: the original lookup error is re-raised when
      `config.jax_raise_persistent_cache_errors` is true.
  """
  # Imported here to avoid an import cycle between jax and jax.experimental.
  from jax.experimental.compilation_cache import compilation_cache as cc

  try:
    executable = cc.get_executable(computation, compile_options, backend)
  except Exception as ex:
    if not config.jax_raise_persistent_cache_errors:
      # Downgrade the failure to a warning so the caller can carry on as if
      # nothing had been found in the cache.
      warnings.warn(
          f"Error reading persistent compilation cache entry for "
          f"'{module_name}': {type(ex).__name__}: {ex}")
      return None
    raise
  return executable
|
||||
|
||||
def _cache_write(computation: Union[str, bytes, ir.Module],
                 module_name: str,
                 compile_options: CompileOptions,
                 backend: Backend,
                 compiled: XlaExecutable):
  """Writes the compiled result for `computation` to the persistent cache.

  Args:
    computation: forwarded unchanged to `cc.put_executable`.
    module_name: name of the module; passed to `cc.put_executable` and used
      in the warning text.
    compile_options: forwarded unchanged to `cc.put_executable`.
    backend: forwarded unchanged to `cc.put_executable`.
    compiled: the executable to store.

  Raises:
    Exception: the original write error is re-raised when
      `config.jax_raise_persistent_cache_errors` is true; otherwise the error
      is reported as a warning and the function returns normally.
  """
  # Imported here to avoid an import cycle between jax and jax.experimental.
  from jax.experimental.compilation_cache import compilation_cache as cc

  try:
    cc.put_executable(module_name, computation, compile_options, compiled,
                      backend)
  except Exception as ex:
    if not config.jax_raise_persistent_cache_errors:
      # Report the failed write as a warning instead of propagating it.
      warnings.warn(
          f"Error writing persistent compilation cache entry for "
          f"'{module_name}': {type(ex).__name__}: {ex}")
      return
    raise
|
||||
|
||||
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
|
||||
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
|
||||
|
@ -19,7 +19,8 @@ import random
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
from unittest import mock, SkipTest
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax.experimental import PartitionSpec as P
|
||||
@ -35,9 +36,12 @@ from jax._src.lib import xla_client
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax._src.config import raise_persistent_cache_errors
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
@jtu.with_config(jax_raise_persistent_cache_errors=True)
|
||||
class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -295,6 +299,38 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
files_in_directory = len(os.listdir(tmpdir))
|
||||
self.assertEqual(files_in_directory, 2)
|
||||
|
||||
def test_cache_write_warning(self):
  """A failing cache `put` yields exactly one warning and a correct result."""
  expected_message = (
      "Error writing persistent compilation cache entry "
      "for 'jit__lambda_': RuntimeError: test error")

  with tempfile.TemporaryDirectory() as tmpdir:
    cc.initialize_cache(tmpdir)
    # Must stay a lambda: the expected module name 'jit__lambda_' is derived
    # from it.
    square = jit(lambda x: x*x)

    # Context managers are entered in the same order as a single combined
    # `with` statement would enter them.
    with raise_persistent_cache_errors(False):
      with mock.patch.object(cc._cache.__class__, 'put') as put_mock:
        with warnings.catch_warnings(record=True) as caught:
          put_mock.side_effect = RuntimeError("test error")
          self.assertEqual(square(2), 4)
          self.assertLen(caught, 1)
          self.assertIn(expected_message, str(caught[0].message))
|
||||
|
||||
def test_cache_read_warning(self):
  """A failing cache `get` yields exactly one warning and a correct result."""
  expected_message = (
      "Error reading persistent compilation cache entry "
      "for 'jit__lambda_': RuntimeError: test error")

  with tempfile.TemporaryDirectory() as tmpdir:
    cc.initialize_cache(tmpdir)
    # Must stay a lambda: the expected module name 'jit__lambda_' is derived
    # from it.
    square = jit(lambda x: x*x)

    # Context managers are entered in the same order as a single combined
    # `with` statement would enter them.
    with raise_persistent_cache_errors(False):
      with mock.patch.object(cc._cache.__class__, 'get') as get_mock:
        with warnings.catch_warnings(record=True) as caught:
          get_mock.side_effect = RuntimeError("test error")
          self.assertEqual(square(2), 4)
          self.assertLen(caught, 1)
          self.assertIn(expected_message, str(caught[0].message))
|
||||
|
||||
def create_new_debug_options(self, debug_options_obj):
|
||||
debug_options_obj.xla_cpu_enable_fast_math = False
|
||||
debug_options_obj.xla_cpu_fast_math_honor_infs = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user