mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Tracebacks no longer have JAX-internal frames prepended by default
This commit is contained in:
parent
a22c4773e1
commit
5e276d0935
@ -1025,17 +1025,21 @@ default_matmul_precision = config.define_enum_state(
|
||||
|
||||
traceback_filtering = config.define_enum_state(
|
||||
name = 'jax_traceback_filtering',
|
||||
enum_values=["off", "tracebackhide", "remove_frames", "auto"],
|
||||
enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
|
||||
"auto"],
|
||||
default="auto",
|
||||
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
|
||||
"Valid values are:\n"
|
||||
" * \"off\": disables traceback filtering.\n"
|
||||
" * \"auto\": use \"tracebackhide\" if running under a sufficiently "
|
||||
"new IPython, or \"remove_frames\" otherwise.\n"
|
||||
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to "
|
||||
" * \"auto\": use \"tracebackhide\" if running under a sufficiently"
|
||||
" new IPython, or \"remove_frames\" otherwise.\n"
|
||||
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to"
|
||||
" hidden stack frames, which some traceback printers support.\n"
|
||||
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
|
||||
" the unfiltered traceback as a __cause__ of the exception.\n")
|
||||
" * \"remove_frames\": removes hidden frames from tracebacks, and adds"
|
||||
" the unfiltered traceback as a __cause__ of the exception.\n"
|
||||
" * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds"
|
||||
" a brief message (to the __cause__ of the exception) describing that this has"
|
||||
" happened.\n")
|
||||
|
||||
# This flag is for internal use.
|
||||
# TODO(tianjianlu): Removes once we always enable cusparse lowering.
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, Optional, TypeVar, cast
|
||||
@ -114,6 +115,16 @@ def format_exception_only(e: BaseException) -> str:
|
||||
|
||||
class UnfilteredStackTrace(Exception): pass
|
||||
|
||||
_simplified_tb_msg = ("For simplicity, JAX has removed its internal frames from the "
|
||||
"traceback of the following exception. Set "
|
||||
"JAX_TRACEBACK_FILTERING=off to include these.")
|
||||
|
||||
class SimplifiedTraceback(Exception):
|
||||
def __str__(self):
|
||||
return _simplified_tb_msg
|
||||
|
||||
SimplifiedTraceback.__module__ = "jax.errors"
|
||||
|
||||
def _running_under_ipython() -> bool:
|
||||
"""Returns true if we appear to be in an IPython session."""
|
||||
try:
|
||||
@ -133,7 +144,7 @@ def _filtering_mode() -> str:
|
||||
if (_running_under_ipython() and _ipython_supports_tracebackhide()):
|
||||
mode = "tracebackhide"
|
||||
else:
|
||||
mode = "remove_frames"
|
||||
mode = "quiet_remove_frames"
|
||||
return mode
|
||||
|
||||
def api_boundary(fun: C) -> C:
|
||||
@ -171,22 +182,12 @@ def api_boundary(fun: C) -> C:
|
||||
if mode == "tracebackhide":
|
||||
_add_tracebackhide_to_hidden_frames(e.__traceback__)
|
||||
raise
|
||||
assert mode == "remove_frames", mode
|
||||
|
||||
filtered_tb, unfiltered, mode = None, None, None
|
||||
filtered_tb, unfiltered = None, None
|
||||
try:
|
||||
filtered_tb = filter_traceback(e.__traceback__)
|
||||
msg = format_exception_only(e)
|
||||
msg = f'{msg}\n\n{_jax_message_append}'
|
||||
unfiltered = UnfilteredStackTrace(msg)
|
||||
unfiltered.with_traceback(_add_call_stack_frames(e.__traceback__))
|
||||
unfiltered.__context__ = e.__context__
|
||||
unfiltered.__cause__ = e.__cause__
|
||||
unfiltered.__suppress_context__ = e.__suppress_context__
|
||||
e.__context__ = None
|
||||
e.__cause__ = unfiltered
|
||||
|
||||
e.__traceback__ = filtered_tb
|
||||
tb = e.__traceback__
|
||||
filtered_tb = filter_traceback(tb)
|
||||
e.with_traceback(filtered_tb)
|
||||
# In Python < 3.11, there seems to be no way to alter the currently
|
||||
# raised exception traceback, except via the C API. The interpreter
|
||||
# keeps a copy of the traceback (exc_traceback) that is separate to the
|
||||
@ -195,7 +196,28 @@ def api_boundary(fun: C) -> C:
|
||||
# the XLA extension no longer defines a traceback-replacing method at
|
||||
# Python 3.11 and onward.
|
||||
if hasattr(xla_extension, "replace_thread_exc_traceback"):
|
||||
# TODO(kidger): remove this line once Python 3.11 is the minimum supported
|
||||
# version.
|
||||
xla_extension.replace_thread_exc_traceback(filtered_tb)
|
||||
if sys.version_info >= (3, 11) and mode == "quiet_remove_frames":
|
||||
e.add_note("--------------------\n" + _simplified_tb_msg)
|
||||
else:
|
||||
if mode == "quiet_remove_frames":
|
||||
# TODO(kidger): remove `SimplifiedTraceback` once Python 3.11 is the minimum
|
||||
# supported version.
|
||||
jax_error = SimplifiedTraceback()
|
||||
elif mode == "remove_frames":
|
||||
msg = format_exception_only(e)
|
||||
msg = f'{msg}\n\n{_jax_message_append}'
|
||||
jax_error = UnfilteredStackTrace(msg)
|
||||
jax_error.with_traceback(_add_call_stack_frames(tb))
|
||||
else:
|
||||
raise ValueError(f"JAX_TRACEBACK_FILTERING={mode} is not a valid value.")
|
||||
jax_error.__cause__ = e.__cause__
|
||||
jax_error.__context__ = e.__context__
|
||||
jax_error.__suppress_context__ = e.__suppress_context__
|
||||
e.__cause__ = jax_error
|
||||
e.__context__ = None
|
||||
raise
|
||||
finally:
|
||||
del filtered_tb
|
||||
|
@ -25,3 +25,4 @@ from jax._src.errors import (
|
||||
TracerIntegerConversionError as TracerIntegerConversionError,
|
||||
UnexpectedTracerError as UnexpectedTracerError,
|
||||
)
|
||||
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -46,7 +47,12 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
|
||||
test.assertRaises(etype, f)
|
||||
e = get_exception(etype, f)
|
||||
c = e.__cause__
|
||||
if filter_mode == "remove_frames":
|
||||
if filter_mode == "quiet_remove_frames":
|
||||
if sys.version_info >= (3, 11):
|
||||
assert any("For simplicity" in x for x in e.__notes__)
|
||||
else:
|
||||
test.assertIsInstance(c, jax.errors.SimplifiedTraceback)
|
||||
elif filter_mode == "remove_frames":
|
||||
test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
|
||||
else:
|
||||
test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace))
|
||||
@ -74,7 +80,7 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
|
||||
@jtu.with_config(jax_traceback_filtering='auto') # JaxTestCase defaults to off.
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{f}", "filter_mode": f}
|
||||
for f in ("tracebackhide", "remove_frames"))
|
||||
for f in ("tracebackhide", "remove_frames", "quiet_remove_frames"))
|
||||
class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
def test_nested_jit(self, filter_mode):
|
||||
@ -347,8 +353,12 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
check_filtered_stack_trace(self, TypeError, f, [
|
||||
('<lambda>', 'f = lambda: outer'),
|
||||
('outer', 'raise TypeError')], filter_mode=filter_mode)
|
||||
e = get_exception(TypeError, f)
|
||||
self.assertIsInstance(e.__cause__, traceback_util.UnfilteredStackTrace)
|
||||
e = get_exception(TypeError, f) # Uses the default JAX_TRACEBACK_FILTERING=auto
|
||||
if sys.version_info >= (3, 11):
|
||||
assert any("For simplicity" in x for x in e.__notes__)
|
||||
self.assertIsInstance(e.__cause__, ValueError)
|
||||
else:
|
||||
self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback)
|
||||
self.assertIsInstance(e.__cause__.__cause__, ValueError)
|
||||
|
||||
def test_null_traceback(self, filter_mode):
|
||||
@ -375,6 +385,11 @@ class UserContextTracebackTest(jtu.JaxTestCase):
|
||||
e = exc
|
||||
self.assertIsNot(e, None)
|
||||
self.assertIn("invalid value", str(e))
|
||||
if sys.version_info >= (3, 11):
|
||||
self.assertIsInstance(
|
||||
e.__cause__,
|
||||
source_info_util.JaxStackTraceBeforeTransformation)
|
||||
else:
|
||||
self.assertIsInstance(
|
||||
e.__cause__.__cause__,
|
||||
source_info_util.JaxStackTraceBeforeTransformation)
|
||||
|
Loading…
x
Reference in New Issue
Block a user