From 5e276d0935c68137b7410acfbc115361057f2f5f Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 3 Aug 2023 10:20:29 -0700 Subject: [PATCH] Tracebacks no longer have JAX-internal frames prepended by default --- jax/_src/config.py | 16 +++++++----- jax/_src/traceback_util.py | 52 +++++++++++++++++++++++++++----------- jax/errors.py | 1 + tests/errors_test.py | 31 +++++++++++++++++------ 4 files changed, 71 insertions(+), 29 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index c7388c477..4c5f01b6a 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1025,17 +1025,21 @@ default_matmul_precision = config.define_enum_state( traceback_filtering = config.define_enum_state( name = 'jax_traceback_filtering', - enum_values=["off", "tracebackhide", "remove_frames", "auto"], + enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", + "auto"], default="auto", help="Controls how JAX filters internal frames out of tracebacks.\n\n" "Valid values are:\n" " * \"off\": disables traceback filtering.\n" - " * \"auto\": use \"tracebackhide\" if running under a sufficiently " - "new IPython, or \"remove_frames\" otherwise.\n" - " * \"tracebackhide\": adds \"__tracebackhide__\" annotations to " + " * \"auto\": use \"tracebackhide\" if running under a sufficiently" + " new IPython, or \"remove_frames\" otherwise.\n" + " * \"tracebackhide\": adds \"__tracebackhide__\" annotations to" " hidden stack frames, which some traceback printers support.\n" - " * \"remove_frames\": removes hidden frames from tracebacks, and adds " - " the unfiltered traceback as a __cause__ of the exception.\n") + " * \"remove_frames\": removes hidden frames from tracebacks, and adds" + " the unfiltered traceback as a __cause__ of the exception.\n" + " * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds" + " a brief message (to the __cause__ of the exception) describing that this has" + " happened.\n") # This flag is for internal use. # TODO(tianjianlu): Removes once we always enable cusparse lowering. diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index c1ea67c58..6e837eccd 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -14,6 +14,7 @@ import functools import os +import sys import traceback import types from typing import Any, Callable, Optional, TypeVar, cast @@ -114,6 +115,16 @@ def format_exception_only(e: BaseException) -> str: class UnfilteredStackTrace(Exception): pass +_simplified_tb_msg = ("For simplicity, JAX has removed its internal frames from the " + "traceback of the following exception. Set " + "JAX_TRACEBACK_FILTERING=off to include these.") + +class SimplifiedTraceback(Exception): + def __str__(self): + return _simplified_tb_msg + +SimplifiedTraceback.__module__ = "jax.errors" + def _running_under_ipython() -> bool: """Returns true if we appear to be in an IPython session.""" try: @@ -133,7 +144,7 @@ def _filtering_mode() -> str: if (_running_under_ipython() and _ipython_supports_tracebackhide()): mode = "tracebackhide" else: - mode = "remove_frames" + mode = "quiet_remove_frames" return mode def api_boundary(fun: C) -> C: @@ -171,22 +182,12 @@ def api_boundary(fun: C) -> C: if mode == "tracebackhide": _add_tracebackhide_to_hidden_frames(e.__traceback__) raise - assert mode == "remove_frames", mode - filtered_tb, unfiltered, mode = None, None, None + filtered_tb, unfiltered = None, None try: - filtered_tb = filter_traceback(e.__traceback__) - msg = format_exception_only(e) - msg = f'{msg}\n\n{_jax_message_append}' - unfiltered = UnfilteredStackTrace(msg) - unfiltered.with_traceback(_add_call_stack_frames(e.__traceback__)) - unfiltered.__context__ = e.__context__ - unfiltered.__cause__ = e.__cause__ - unfiltered.__suppress_context__ = e.__suppress_context__ - e.__context__ = None - e.__cause__ = unfiltered - - e.__traceback__ = filtered_tb + tb = e.__traceback__ + filtered_tb = filter_traceback(tb) + e.with_traceback(filtered_tb) # In Python < 3.11, there seems to be no way to alter the currently # raised exception traceback, except via the C API. The interpreter # keeps a copy of the traceback (exc_traceback) that is separate to the @@ -195,7 +196,28 @@ def api_boundary(fun: C) -> C: # the XLA extension no longer defines a traceback-replacing method at # Python 3.11 and onward. if hasattr(xla_extension, "replace_thread_exc_traceback"): + # TODO(kidger): remove this line once Python 3.11 is the minimum supported + # version. xla_extension.replace_thread_exc_traceback(filtered_tb) + if sys.version_info >= (3, 11) and mode == "quiet_remove_frames": + e.add_note("--------------------\n" + _simplified_tb_msg) + else: + if mode == "quiet_remove_frames": + # TODO(kidger): remove `SimplifiedTraceback` once Python 3.11 is the minimum + # supported version. + jax_error = SimplifiedTraceback() + elif mode == "remove_frames": + msg = format_exception_only(e) + msg = f'{msg}\n\n{_jax_message_append}' + jax_error = UnfilteredStackTrace(msg) + jax_error.with_traceback(_add_call_stack_frames(tb)) + else: + raise ValueError(f"JAX_TRACEBACK_FILTERING={mode} is not a valid value.") + jax_error.__cause__ = e.__cause__ + jax_error.__context__ = e.__context__ + jax_error.__suppress_context__ = e.__suppress_context__ + e.__cause__ = jax_error + e.__context__ = None raise finally: del filtered_tb diff --git a/jax/errors.py b/jax/errors.py index 2801ec401..4b8a0cf75 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -25,3 +25,4 @@ from jax._src.errors import ( TracerIntegerConversionError as TracerIntegerConversionError, UnexpectedTracerError as UnexpectedTracerError, ) +from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/tests/errors_test.py b/tests/errors_test.py index cfb673f8a..c90a8fdfe 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -13,6 +13,7 @@ # limitations under the License. import re +import sys import traceback from absl.testing import absltest @@ -46,7 +47,12 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(), test.assertRaises(etype, f) e = get_exception(etype, f) c = e.__cause__ - if filter_mode == "remove_frames": + if filter_mode == "quiet_remove_frames": + if sys.version_info >= (3, 11): + assert any("For simplicity" in x for x in e.__notes__) + else: + test.assertIsInstance(c, jax.errors.SimplifiedTraceback) + elif filter_mode == "remove_frames": test.assertIsInstance(c, traceback_util.UnfilteredStackTrace) else: test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace)) @@ -74,7 +80,7 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(), @jtu.with_config(jax_traceback_filtering='auto') # JaxTestCase defaults to off. @parameterized.named_parameters( {"testcase_name": f"_{f}", "filter_mode": f} - for f in ("tracebackhide", "remove_frames")) + for f in ("tracebackhide", "remove_frames", "quiet_remove_frames")) class FilteredTracebackTest(jtu.JaxTestCase): def test_nested_jit(self, filter_mode): @@ -347,9 +353,13 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, TypeError, f, [ ('', 'f = lambda: outer'), ('outer', 'raise TypeError')], filter_mode=filter_mode) - e = get_exception(TypeError, f) - self.assertIsInstance(e.__cause__, traceback_util.UnfilteredStackTrace) - self.assertIsInstance(e.__cause__.__cause__, ValueError) + e = get_exception(TypeError, f) # Uses the default JAX_TRACEBACK_FILTERING=auto + if sys.version_info >= (3, 11): + assert any("For simplicity" in x for x in e.__notes__) + self.assertIsInstance(e.__cause__, ValueError) + else: + self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback) + self.assertIsInstance(e.__cause__.__cause__, ValueError) def test_null_traceback(self, filter_mode): class TestA: pass @@ -375,9 +385,14 @@ class UserContextTracebackTest(jtu.JaxTestCase): e = exc self.assertIsNot(e, None) self.assertIn("invalid value", str(e)) - self.assertIsInstance( - e.__cause__.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) + if sys.version_info >= (3, 11): + self.assertIsInstance( + e.__cause__, + source_info_util.JaxStackTraceBeforeTransformation) + else: + self.assertIsInstance( + e.__cause__.__cause__, + source_info_util.JaxStackTraceBeforeTransformation) class CustomErrorsTest(jtu.JaxTestCase):