diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fda1291a..6f30310f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master). * New features: * The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`. + * A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters + tracebacks. + * A new traceback filtering mode using `__tracebackhide__` is now enabled by + default in sufficiently recent versions of IPython. * Breaking changes: diff --git a/jax/_src/config.py b/jax/_src/config.py index 244b4923f..3f445fd6c 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -250,7 +250,10 @@ class Config: See docstring for ``define_bool_state``. """ name = name.lower() - self.DEFINE_enum(name, os.getenv(name.upper(), default), + default = os.getenv(name.upper(), default) + if default is not None and default not in enum_values: + raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}") + self.DEFINE_enum(name, default, enum_values=enum_values, help=help, update_hook=update_global_hook) self._contextmanager_flags.add(name) @@ -517,3 +520,17 @@ default_matmul_precision = config.define_enum_state( update_global_jit_state(default_matmul_precision=val), update_thread_local_hook=lambda val: \ update_thread_local_jit_state(default_matmul_precision=val)) + +traceback_filtering = config.define_enum_state( + name = 'jax_traceback_filtering', + enum_values=["off", "tracebackhide", "remove_frames", "auto"], + default="auto", + help="Controls how JAX filters internal frames out of tracebacks.\n\n" + "Valid values are:\n" + " * \"off\": disables traceback filtering.\n" + " * \"auto\": use \"tracebackhide\" if running under a sufficiently " + "new IPython, or \"remove_frames\" otherwise.\n" + " * \"tracebackhide\": adds \"__tracebackhide__\" annotations to " + " hidden stack frames, which some traceback printers support.\n" + " * \"remove_frames\": removes hidden frames from tracebacks, and adds " + " the unfiltered traceback as a __cause__ of the exception.\n") diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index 374c11377..aeef289b7 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -16,7 +16,9 @@ import os import sys import traceback import types +import warnings +import jax from jax.lib import xla_extension from jax._src import util @@ -56,6 +58,11 @@ def include_frame(f): def ignore_known_hidden_frame(f): return 'importlib._bootstrap' in f.f_code.co_filename +def add_tracebackhide_to_hidden_frames(tb): + for f, lineno in traceback.walk_tb(tb): + if not include_frame(f): + f.f_locals["__tracebackhide__"] = True + def filter_traceback(tb): out = None # Scan the traceback and collect relevant frames. @@ -111,6 +118,38 @@ make_traceback = (types.TracebackType if sys.version_info >= (3, 7) else def filtered_tracebacks_supported(): return make_traceback is not None +def running_under_ipython(): + """Returns true if we appear to be in an IPython session.""" + try: + get_ipython() # type: ignore + return True + except NameError: + return False + +def python_supports_tracebackhide(): + """Returns true we can add __tracebackhide__ to frames.""" + # TODO(phawkins): remove this test after droppping Python 3.6 support. + return sys.version_info[:2] >= (3, 7) + +def ipython_supports_tracebackhide(): + """Returns true if the IPython version supports __tracebackhide__.""" + import IPython # type: ignore + return IPython.version_info[:2] >= (7, 17) + +def filtering_mode(): + mode = jax.config.jax_traceback_filtering + if mode is None or mode == "auto": + if (running_under_ipython() and ipython_supports_tracebackhide() and + python_supports_tracebackhide()): + mode = "tracebackhide" + else: + mode = "remove_frames" + if mode == "tracebackhide" and not python_supports_tracebackhide(): + warnings.warn("--jax_traceback_filtering=tracebackhide requires Python 3.7 " + "or newer.") + mode = "remove_frames" + return mode + def api_boundary(fun): '''Wraps ``fun`` to form a boundary for filtering exception tracebacks. @@ -139,39 +178,47 @@ def api_boundary(fun): @util.wraps(fun) def reraise_with_filtered_traceback(*args, **kwargs): + __tracebackhide__ = True try: return fun(*args, **kwargs) except Exception as e: - if not is_under_reraiser(e): - filtered_tb, unfiltered = None, None - try: - filtered_tb = filter_traceback(e.__traceback__) - if filtered_tb is None: - raise - msg = format_exception_only(e) - msg = f'{msg}\n\n{_jax_message_append}' - unfiltered = UnfilteredStackTrace(msg) - unfiltered.with_traceback(add_call_stack_frames(e.__traceback__)) - unfiltered.__context__ = e.__context__ - unfiltered.__cause__ = e.__cause__ - unfiltered.__suppress_context__ = e.__suppress_context__ - e.__context__ = None - e.__cause__ = unfiltered - # There seems to be no way to alter the currently raised exception's - # traceback, except via the C API. The currently raised exception - # is part of the interpreter's thread state: value `e` is a copy. - if hasattr(xla_extension, 'replace_thread_exc_traceback'): - xla_extension.replace_thread_exc_traceback(filtered_tb) - raise - else: - # TODO(phawkins): remove this case when jaxlib 0.1.66 is the - # minimum. - - # Fallback case for older jaxlibs; includes the current frame. - raise e.with_traceback(filtered_tb) - finally: - del filtered_tb - del unfiltered - else: + mode = filtering_mode() + if is_under_reraiser(e) or mode == "off": raise + if mode == "tracebackhide": + add_tracebackhide_to_hidden_frames(e.__traceback__) + raise + assert mode == "remove_frames", mode + + filtered_tb, unfiltered, mode = None, None, None + try: + filtered_tb = filter_traceback(e.__traceback__) + if filtered_tb is None: + raise + msg = format_exception_only(e) + msg = f'{msg}\n\n{_jax_message_append}' + unfiltered = UnfilteredStackTrace(msg) + unfiltered.with_traceback(add_call_stack_frames(e.__traceback__)) + unfiltered.__context__ = e.__context__ + unfiltered.__cause__ = e.__cause__ + unfiltered.__suppress_context__ = e.__suppress_context__ + e.__context__ = None + e.__cause__ = unfiltered + + # There seems to be no way to alter the currently raised exception's + # traceback, except via the C API. The currently raised exception + # is part of the interpreter's thread state: value `e` is a copy. + if hasattr(xla_extension, 'replace_thread_exc_traceback'): + xla_extension.replace_thread_exc_traceback(filtered_tb) + raise + else: + # TODO(phawkins): remove this case when jaxlib 0.1.66 is the + # minimum. + + # Fallback case for older jaxlibs; includes the current frame. + raise e.with_traceback(filtered_tb) + finally: + del filtered_tb + del unfiltered + del mode return reraise_with_filtered_traceback diff --git a/tests/errors_test.py b/tests/errors_test.py index 7651ba7e2..a62880a14 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -13,6 +13,7 @@ # limitations under the License. import re +import sys import traceback import unittest @@ -40,16 +41,31 @@ def get_exception(etype, f): return e assert False -def check_filtered_stack_trace(test, etype, f, frame_patterns=[]): - test.assertRaises(etype, f) - e = get_exception(etype, f) +def check_filtered_stack_trace(test, etype, f, frame_patterns=[], + filter_mode="remove_frames"): + with jax._src.config.traceback_filtering(filter_mode): + test.assertRaises(etype, f) + e = get_exception(etype, f) c = e.__cause__ - test.assertIsInstance(c, traceback_util.UnfilteredStackTrace) - c_tb = traceback.format_tb(e.__traceback__) - # TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum. - if not hasattr(xla_extension, "replace_thread_exc_traceback"): - c_tb = [t for t in c_tb if "reraise_with_filtered_traceback" not in t] + if filter_mode == "remove_frames": + test.assertIsInstance(c, traceback_util.UnfilteredStackTrace) + else: + test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace)) + if frame_patterns: + frames = [] + for frame, lineno in traceback.walk_tb(e.__traceback__): + if filter_mode == "tracebackhide": + if "__tracebackhide__" in frame.f_locals.keys(): + continue + elif filter_mode == "remove_frames": + # TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum. + if (not hasattr(xla_extension, "replace_thread_exc_traceback") and + frame.f_code.co_name == "reraise_with_filtered_traceback"): + continue + frames.append((frame, lineno)) + + c_tb = traceback.format_list(traceback.StackSummary.extract(frames)) for (fname_pat, line_pat), frame_fmt in zip( reversed(frame_patterns), reversed(c_tb)): file = re.escape(__file__) @@ -60,12 +76,21 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[]): f', in {fname_pat}' r'\n\s*' f'{line_pat}') test.assertRegex(frame_fmt, full_pat) +def skip_if_unsupported_filter_mode(filter_mode): + if (filter_mode == "remove_frames" and + not traceback_util.filtered_tracebacks_supported()): + raise unittest.SkipTest('Filtered tracebacks not supported') + elif filter_mode == "tracebackhide" and sys.version_info[:2] < (3, 7): + raise unittest.SkipTest('Tracebackhide requires Python 3.7 or newer') + +@parameterized.named_parameters( + {"testcase_name": f"_{f}", "filter_mode": f} + for f in ("tracebackhide", "remove_frames")) class FilteredTracebackTest(jtu.JaxTestCase): - def test_nested_jit(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_nested_jit(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) @jit def innermost(x): @@ -83,11 +108,11 @@ class FilteredTracebackTest(jtu.JaxTestCase): ('', 'f = lambda: outermost'), ('outermost', 'return 2 + inbetween(x)'), ('inbetween', 'return 1 + innermost(x)'), - ('innermost', 'assert False')]) + ('innermost', 'assert False')], + filter_mode=filter_mode) - def test_nested_jit_and_vmap(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_nested_jit_and_vmap(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) @jit def innermost(x): @@ -105,11 +130,11 @@ class FilteredTracebackTest(jtu.JaxTestCase): ('', 'f = lambda: outermost'), ('outermost', 'return 2 + inbetween(x)'), ('inbetween', 'return 1 + vmap(innermost)(x)'), - ('innermost', 'assert False')]) + ('innermost', 'assert False')], + filter_mode=filter_mode) - def test_nested_jit_and_grad(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_nested_jit_and_grad(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) @jit def innermost(x): @@ -127,11 +152,10 @@ class FilteredTracebackTest(jtu.JaxTestCase): ('', 'f = lambda: outermost'), ('outermost', 'return 2 + inbetween(x)'), ('inbetween', 'return 1 + grad(innermost)(x)'), - ]) + ], filter_mode=filter_mode) - def test_lax_cond(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_cond(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(_): assert False @@ -142,11 +166,11 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f, [ ('f', 'return lax.cond(True, err, lambda _: (), ())'), - ('err', 'assert False')]) + ('err', 'assert False')], + filter_mode=filter_mode) - def test_lax_switch(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_switch(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(_): assert False @@ -158,11 +182,10 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f, [ ('f', 'return lax.switch(1, branches, ())'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_lax_scan(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_scan(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(*_): assert False @@ -173,11 +196,10 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f, [ ('f', 'return lax.scan(err, (), (), 3)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_lax_fori_loop(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_fori_loop(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(*_): assert False @@ -188,11 +210,10 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f, [ ('f', 'return lax.fori_loop(0, 3, err, ())'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_lax_while_loop(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_while_loop(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(*_): assert False @@ -204,11 +225,10 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f, [ ('f', 'return lax.while_loop(pred, err, ())'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_lax_map(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_map(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(_): assert False @@ -220,11 +240,10 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f, [ ('f', 'return lax.map(err, xs)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_lax_custom_root(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_custom_root(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(*_): assert False @@ -242,17 +261,16 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f1, [ ('f1', 'return lax.custom_root(g, 0., err, solve)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) check_filtered_stack_trace(self, AssertionError, f2, [ ('f2', 'return lax.custom_root(g, 0., solve, err)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) check_filtered_stack_trace(self, AssertionError, f3, [ ('f3', 'return lax.custom_root(err, 0., solve, solve)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_lax_custom_linear_solve(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_custom_linear_solve(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(*_): assert False @@ -269,14 +287,13 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f1, [ ('f1', 'return lax.custom_linear_solve(err, b, solve)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) check_filtered_stack_trace(self, AssertionError, f2, [ ('f2', 'return lax.custom_linear_solve(matvec, b, err)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_lax_associative_scan(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_lax_associative_scan(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) def err(*_): assert False @@ -288,11 +305,10 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, AssertionError, f, [ ('f', 'return lax.associative_scan(err, xs)'), - ('err', 'assert False')]) + ('err', 'assert False')], filter_mode=filter_mode) - def test_cause_chain(self): - if not traceback_util.filtered_tracebacks_supported(): - raise unittest.SkipTest('Filtered tracebacks not supported') + def test_cause_chain(self, filter_mode): + skip_if_unsupported_filter_mode(filter_mode) @jit def inner(x): @@ -308,7 +324,7 @@ class FilteredTracebackTest(jtu.JaxTestCase): check_filtered_stack_trace(self, TypeError, f, [ ('', 'f = lambda: outer'), - ('outer', 'raise TypeError')]) + ('outer', 'raise TypeError')], filter_mode=filter_mode) e = get_exception(TypeError, f) self.assertIsInstance(e.__cause__, traceback_util.UnfilteredStackTrace) self.assertIsInstance(e.__cause__.__cause__, ValueError)