rocm_jax/jax/_src/traceback_util.py
Peter Hawkins e8c340623c [JAX] Switch the order of the filtered and unfiltered stack traces in exceptions.
After this change, the filtered stack trace is attached to the main exception, and the unfiltered stack trace becomes a __cause__ exception.

PiperOrigin-RevId: 371509766
2021-05-01 12:41:37 -07:00

175 lines
6.7 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import traceback
import types
from jax.lib import xla_extension
from jax._src import util
# Path prefixes whose frames are dropped from filtered tracebacks. Starts with
# this module and the `util` module; `register_exclusion` appends more.
_exclude_paths = [
    __file__,
    util.__file__,
]
def register_exclusion(path):
  """Adds ``path`` to the list of path prefixes excluded from filtered traces.

  Args:
    path: a file or directory path; frames whose source file is this path or
      lies underneath it are skipped by ``include_frame``.
  """
  _exclude_paths.append(path)
# Explanatory text appended (after a blank line) to the message of the
# `UnfilteredStackTrace` exception built in `api_boundary`.
_jax_message_append = (
    'The stack trace below excludes JAX-internal frames.\n'
    'The preceding is the original exception that occurred, unmodified.\n'
    '\n--------------------')
def path_starts_with(path, path_prefix):
  """Returns whether ``path`` is ``path_prefix`` or is located underneath it.

  Both arguments are made absolute before comparing. The comparison works on
  whole path components (via ``os.path.commonpath``), so ``/a/bc`` is not
  considered to start with ``/a/b``.

  Args:
    path: the path to test; it does not have to exist.
    path_prefix: the candidate prefix; if it does not exist the answer is
      always False.

  Returns:
    True if the longest common sub-path of the two is the same file as
    ``path_prefix``, False otherwise. Filesystem errors are reported as False
    rather than raised, since this runs while another exception is being
    handled and must not mask it.
  """
  path = os.path.abspath(path)
  path_prefix = os.path.abspath(path_prefix)
  if not os.path.exists(path_prefix):
    return False
  try:
    common = os.path.commonpath([path, path_prefix])
  except ValueError:
    # path and path_prefix are both absolute, the only case will raise a
    # ValueError is different drives.
    # https://docs.python.org/3/library/os.path.html#os.path.commonpath
    return False
  try:
    return os.path.samefile(common, path_prefix)
  except OSError:
    # samefile stats both paths; one of them may have vanished or become
    # inaccessible since the existence check above.
    return False
def include_frame(f):
  """Returns True unless the frame's source file is under an excluded path.

  Args:
    f: a frame object; only ``f.f_code.co_filename`` is read.
  """
  filename = f.f_code.co_filename
  for excluded in _exclude_paths:
    if path_starts_with(filename, excluded):
      return False
  return True
def ignore_known_hidden_frame(f):
  """Returns True for frames that should be skipped while scanning the stack.

  When scanning stack traces, we might encounter frames from cpython that are
  removed from printed stack traces, such as frames from parts of importlib.
  These are recognised heuristically: a frame is ignored when the filename of
  its code object contains ``importlib._bootstrap``.

  Args:
    f: a frame object; only ``f.f_code.co_filename`` is read.
  """
  filename = f.f_code.co_filename
  return 'importlib._bootstrap' in filename
def filter_traceback_and_stack(tb):
  """Builds a copy of traceback ``tb`` that keeps only the included frames.

  The traceback is walked from its last entry back to its first, and a new
  chain is built by prepending each kept frame. The last entry of ``tb`` is
  always kept, even if ``include_frame`` rejects it, so the result is only
  None when ``tb`` yields no frames at all.

  Args:
    tb: the traceback to filter.

  Returns:
    The head of the newly built traceback chain, or None.
  """
  entries = list(traceback.walk_tb(tb))
  filtered = None
  for frame, lineno in entries[::-1]:
    keep = include_frame(frame) or filtered is None
    if keep:
      filtered = make_traceback(filtered, frame, frame.f_lasti, lineno)  # pytype: disable=wrong-arg-count
  return filtered
def add_call_stack_frames(tb):
  """Extends traceback ``tb`` upwards with frames from the call stack.

  Starting at ``tb.tb_frame`` and walking towards the outermost caller, each
  frame accepted by ``include_frame`` is prepended to the traceback chain.

  We would like to avoid stepping too far up, e.g. past the exec/eval point of
  a REPL such as IPython. To that end, we stop past the first contiguous bunch
  of module-level frames, if we reach any such frames at all. This is a
  heuristic that might stop in advance of the REPL boundary. For example, if
  the call stack includes module-level frames from the current module A, and
  the current module A was imported from within a function F elsewhere, then
  the stack trace we produce will be truncated at F's frame.

  Args:
    tb: the traceback to extend; it is returned unchanged if no frame is added.

  Returns:
    The head of the extended traceback chain.
  """
  extended = tb
  seen_module_frame = False
  for frame, lineno in traceback.walk_stack(tb.tb_frame):
    # Hidden interpreter frames neither count as module-level nor end the walk.
    if ignore_known_hidden_frame(frame):
      continue
    at_module_level = frame.f_code.co_name == '<module>'
    if seen_module_frame and not at_module_level:
      # We have left the first contiguous run of module-level frames.
      break
    if include_frame(frame):
      extended = make_traceback(extended, frame, frame.f_lasti, lineno)  # pytype: disable=wrong-arg-count
    if at_module_level:
      seen_module_frame = True
  return extended
def is_reraiser_frame(f):
  """Returns True if ``f`` describes the wrapper created by ``api_boundary``.

  Args:
    f: a stack entry exposing ``filename`` and ``name`` attributes, such as
      the entries returned by ``traceback.extract_stack``.
  """
  if f.filename != __file__:
    return False
  return f.name == 'reraise_with_filtered_traceback'
def is_under_reraiser(e):
  """Returns True if an ``api_boundary`` wrapper is higher up the call stack.

  The stack is extracted from the frame of the first entry of
  ``e.__traceback__``. The final entry of the extracted stack, which is that
  frame itself, is not considered.

  Args:
    e: an exception whose ``__traceback__`` is set.
  """
  stack = traceback.extract_stack(e.__traceback__.tb_frame)
  for entry in stack[:-1]:
    if is_reraiser_frame(entry):
      return True
  return False
def format_exception_only(e):
  """Formats exception ``e`` as text, without any traceback.

  Args:
    e: the exception instance to format.

  Returns:
    The joined output of ``traceback.format_exception_only`` with leading and
    trailing whitespace stripped.
  """
  lines = traceback.format_exception_only(type(e), e)
  text = ''.join(lines)
  return text.strip()
class UnfilteredStackTrace(Exception):
  """Exception that ``api_boundary`` attaches as ``__cause__`` of an error.

  Its traceback is the unfiltered one, i.e. it still includes the frames that
  are removed from the traceback of the exception that is re-raised.
  """
# Callable used to build traceback objects: `types.TracebackType` on Python
# 3.7 and newer, otherwise the `make_python_traceback` helper from
# `xla_extension` if it has one (None when it does not).
if sys.version_info >= (3, 7):
  make_traceback = types.TracebackType
else:
  make_traceback = getattr(xla_extension, "make_python_traceback", None)
# Optional `xla_extension` helper that `api_boundary` calls to swap the
# traceback of the exception currently being handled; None if the extension
# does not provide it.
replace_thread_exc_traceback = getattr(xla_extension,
                                       "replace_thread_exc_traceback", None)
def filtered_tracebacks_supported():
  """Returns True if a traceback constructor (``make_traceback``) is available."""
  unsupported = make_traceback is None
  return not unsupported
def api_boundary(fun):
  """Wraps ``fun`` to form a boundary for filtering exception tracebacks.

  When an exception occurs below ``fun``, it is re-raised with a filtered
  traceback: one that imitates the stack trace of the original exception, but
  with JAX-internal frames removed. The complete, unfiltered stack trace is not
  lost; it is attached to the exception as a custom ``__cause__`` of type
  ``UnfilteredStackTrace``, which also takes over the exception's original
  ``__cause__`` and ``__context__``.

  This boundary annotation works in composition with itself. The topmost frame
  corresponding to an ``api_boundary`` is the one below which stack traces are
  filtered. In other words, if ``api_boundary(f)`` calls ``api_boundary(g)``,
  directly or indirectly, the filtered stack trace provided is the same as if
  ``api_boundary(f)`` were to simply call ``g`` instead.

  This annotation is primarily useful in wrapping functions output by JAX's
  transformations. For example, consider ``g = jax.jit(f)``. When ``g`` is
  called, JAX's JIT compilation machinery is invoked, which in turn calls ``f``
  in order to trace and translate it. If the function ``f`` raises an exception,
  the stack unwinds through JAX's JIT internals up to the original call site of
  ``g``. Because the function returned by ``jax.jit`` is annotated as an
  ``api_boundary``, such an exception is reported with a traceback that
  excludes the frames specific to JAX's implementation, while the full
  traceback remains available through its ``__cause__``.

  Args:
    fun: the callable to wrap.

  Returns:
    ``fun`` itself if filtered tracebacks are not supported, otherwise a
    wrapper with the same call signature that re-raises any ``Exception``
    escaping ``fun`` as described above.
  """
  if not filtered_tracebacks_supported():
    return fun

  # NOTE: `is_reraiser_frame` recognises this wrapper by its function name, so
  # the name must not change.
  @util.wraps(fun)
  def reraise_with_filtered_traceback(*args, **kwargs):
    try:
      return fun(*args, **kwargs)
    except Exception as e:
      if is_under_reraiser(e):
        # Another api_boundary wrapper is further up the stack; only the
        # topmost one filters, so propagate the exception untouched.
        raise
      filtered_tb, unfiltered = None, None
      try:
        filtered_tb = filter_traceback_and_stack(e.__traceback__)
        if filtered_tb is None:
          raise
        msg = format_exception_only(e)
        msg = f'{msg}\n\n{_jax_message_append}'
        unfiltered = UnfilteredStackTrace(msg)
        unfiltered.with_traceback(add_call_stack_frames(e.__traceback__))
        # Move the original exception chain onto the unfiltered exception, and
        # make the unfiltered exception the sole cause of `e`.
        unfiltered.__context__ = e.__context__
        unfiltered.__cause__ = e.__cause__
        unfiltered.__suppress_context__ = e.__suppress_context__
        e.__context__ = None
        e.__cause__ = unfiltered
        if replace_thread_exc_traceback is None:
          # TODO(phawkins): remove this case when jaxlib 0.1.66 is the
          # minimum.
          # Fallback case for older jaxlibs; includes the current frame.
          raise e.with_traceback(filtered_tb)
        # There seems to be no way to alter the currently raised exception's
        # traceback, except via the C API. The currently raised exception
        # is part of the interpreter's thread state: value `e` is a copy.
        replace_thread_exc_traceback(filtered_tb)
        raise
      finally:
        # Drop the local references to the traceback and exception objects
        # (presumably to avoid keeping frames alive through reference cycles).
        del filtered_tb
        del unfiltered

  return reraise_with_filtered_traceback