Add a --jax_traceback_filtering flag to control the traceback filtering mode.

Add a new traceback filtering mode that uses __tracebackhide__, and use it in IPython.
This commit is contained in:
Peter Hawkins 2021-06-02 15:22:50 -04:00
parent 46cc654537
commit 2882286b50
4 changed files with 179 additions and 95 deletions

View File

@ -12,6 +12,10 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
* New features:
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
* A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
tracebacks.
* A new traceback filtering mode using `__tracebackhide__` is now enabled by
default in sufficiently recent versions of IPython.
* Breaking changes:

View File

@ -250,7 +250,10 @@ class Config:
See docstring for ``define_bool_state``.
"""
name = name.lower()
self.DEFINE_enum(name, os.getenv(name.upper(), default),
default = os.getenv(name.upper(), default)
if default is not None and default not in enum_values:
raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
self.DEFINE_enum(name, default,
enum_values=enum_values, help=help,
update_hook=update_global_hook)
self._contextmanager_flags.add(name)
@ -517,3 +520,17 @@ default_matmul_precision = config.define_enum_state(
update_global_jit_state(default_matmul_precision=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(default_matmul_precision=val))
traceback_filtering = config.define_enum_state(
name = 'jax_traceback_filtering',
enum_values=["off", "tracebackhide", "remove_frames", "auto"],
default="auto",
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
"Valid values are:\n"
" * \"off\": disables traceback filtering.\n"
" * \"auto\": use \"tracebackhide\" if running under a sufficiently "
"new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to "
" hidden stack frames, which some traceback printers support.\n"
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
" the unfiltered traceback as a __cause__ of the exception.\n")

View File

@ -16,7 +16,9 @@ import os
import sys
import traceback
import types
import warnings
import jax
from jax.lib import xla_extension
from jax._src import util
@ -56,6 +58,11 @@ def include_frame(f):
def ignore_known_hidden_frame(f):
return 'importlib._bootstrap' in f.f_code.co_filename
def add_tracebackhide_to_hidden_frames(tb):
for f, lineno in traceback.walk_tb(tb):
if not include_frame(f):
f.f_locals["__tracebackhide__"] = True
def filter_traceback(tb):
out = None
# Scan the traceback and collect relevant frames.
@ -111,6 +118,38 @@ make_traceback = (types.TracebackType if sys.version_info >= (3, 7) else
def filtered_tracebacks_supported():
return make_traceback is not None
def running_under_ipython():
"""Returns true if we appear to be in an IPython session."""
try:
get_ipython() # type: ignore
return True
except NameError:
return False
def python_supports_tracebackhide():
"""Returns true we can add __tracebackhide__ to frames."""
# TODO(phawkins): remove this test after droppping Python 3.6 support.
return sys.version_info[:2] >= (3, 7)
def ipython_supports_tracebackhide():
"""Returns true if the IPython version supports __tracebackhide__."""
import IPython # type: ignore
return IPython.version_info[:2] >= (7, 17)
def filtering_mode():
mode = jax.config.jax_traceback_filtering
if mode is None or mode == "auto":
if (running_under_ipython() and ipython_supports_tracebackhide() and
python_supports_tracebackhide()):
mode = "tracebackhide"
else:
mode = "remove_frames"
if mode == "tracebackhide" and not python_supports_tracebackhide():
warnings.warn("--jax_traceback_filtering=tracebackhide requires Python 3.7 "
"or newer.")
mode = "remove_frames"
return mode
def api_boundary(fun):
'''Wraps ``fun`` to form a boundary for filtering exception tracebacks.
@ -139,39 +178,47 @@ def api_boundary(fun):
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
return fun(*args, **kwargs)
except Exception as e:
if not is_under_reraiser(e):
filtered_tb, unfiltered = None, None
try:
filtered_tb = filter_traceback(e.__traceback__)
if filtered_tb is None:
raise
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
unfiltered = UnfilteredStackTrace(msg)
unfiltered.with_traceback(add_call_stack_frames(e.__traceback__))
unfiltered.__context__ = e.__context__
unfiltered.__cause__ = e.__cause__
unfiltered.__suppress_context__ = e.__suppress_context__
e.__context__ = None
e.__cause__ = unfiltered
# There seems to be no way to alter the currently raised exception's
# traceback, except via the C API. The currently raised exception
# is part of the interpreter's thread state: value `e` is a copy.
if hasattr(xla_extension, 'replace_thread_exc_traceback'):
xla_extension.replace_thread_exc_traceback(filtered_tb)
raise
else:
# TODO(phawkins): remove this case when jaxlib 0.1.66 is the
# minimum.
# Fallback case for older jaxlibs; includes the current frame.
raise e.with_traceback(filtered_tb)
finally:
del filtered_tb
del unfiltered
else:
mode = filtering_mode()
if is_under_reraiser(e) or mode == "off":
raise
if mode == "tracebackhide":
add_tracebackhide_to_hidden_frames(e.__traceback__)
raise
assert mode == "remove_frames", mode
filtered_tb, unfiltered, mode = None, None, None
try:
filtered_tb = filter_traceback(e.__traceback__)
if filtered_tb is None:
raise
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
unfiltered = UnfilteredStackTrace(msg)
unfiltered.with_traceback(add_call_stack_frames(e.__traceback__))
unfiltered.__context__ = e.__context__
unfiltered.__cause__ = e.__cause__
unfiltered.__suppress_context__ = e.__suppress_context__
e.__context__ = None
e.__cause__ = unfiltered
# There seems to be no way to alter the currently raised exception's
# traceback, except via the C API. The currently raised exception
# is part of the interpreter's thread state: value `e` is a copy.
if hasattr(xla_extension, 'replace_thread_exc_traceback'):
xla_extension.replace_thread_exc_traceback(filtered_tb)
raise
else:
# TODO(phawkins): remove this case when jaxlib 0.1.66 is the
# minimum.
# Fallback case for older jaxlibs; includes the current frame.
raise e.with_traceback(filtered_tb)
finally:
del filtered_tb
del unfiltered
del mode
return reraise_with_filtered_traceback

View File

@ -13,6 +13,7 @@
# limitations under the License.
import re
import sys
import traceback
import unittest
@ -40,16 +41,31 @@ def get_exception(etype, f):
return e
assert False
def check_filtered_stack_trace(test, etype, f, frame_patterns=[]):
test.assertRaises(etype, f)
e = get_exception(etype, f)
def check_filtered_stack_trace(test, etype, f, frame_patterns=[],
filter_mode="remove_frames"):
with jax._src.config.traceback_filtering(filter_mode):
test.assertRaises(etype, f)
e = get_exception(etype, f)
c = e.__cause__
test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
c_tb = traceback.format_tb(e.__traceback__)
# TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum.
if not hasattr(xla_extension, "replace_thread_exc_traceback"):
c_tb = [t for t in c_tb if "reraise_with_filtered_traceback" not in t]
if filter_mode == "remove_frames":
test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
else:
test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace))
if frame_patterns:
frames = []
for frame, lineno in traceback.walk_tb(e.__traceback__):
if filter_mode == "tracebackhide":
if "__tracebackhide__" in frame.f_locals.keys():
continue
elif filter_mode == "remove_frames":
# TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum.
if (not hasattr(xla_extension, "replace_thread_exc_traceback") and
frame.f_code.co_name == "reraise_with_filtered_traceback"):
continue
frames.append((frame, lineno))
c_tb = traceback.format_list(traceback.StackSummary.extract(frames))
for (fname_pat, line_pat), frame_fmt in zip(
reversed(frame_patterns), reversed(c_tb)):
file = re.escape(__file__)
@ -60,12 +76,21 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[]):
f', in {fname_pat}' r'\n\s*' f'{line_pat}')
test.assertRegex(frame_fmt, full_pat)
def skip_if_unsupported_filter_mode(filter_mode):
if (filter_mode == "remove_frames" and
not traceback_util.filtered_tracebacks_supported()):
raise unittest.SkipTest('Filtered tracebacks not supported')
elif filter_mode == "tracebackhide" and sys.version_info[:2] < (3, 7):
raise unittest.SkipTest('Tracebackhide requires Python 3.7 or newer')
@parameterized.named_parameters(
{"testcase_name": f"_{f}", "filter_mode": f}
for f in ("tracebackhide", "remove_frames"))
class FilteredTracebackTest(jtu.JaxTestCase):
def test_nested_jit(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_nested_jit(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
@jit
def innermost(x):
@ -83,11 +108,11 @@ class FilteredTracebackTest(jtu.JaxTestCase):
('<lambda>', 'f = lambda: outermost'),
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + innermost(x)'),
('innermost', 'assert False')])
('innermost', 'assert False')],
filter_mode=filter_mode)
def test_nested_jit_and_vmap(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_nested_jit_and_vmap(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
@jit
def innermost(x):
@ -105,11 +130,11 @@ class FilteredTracebackTest(jtu.JaxTestCase):
('<lambda>', 'f = lambda: outermost'),
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + vmap(innermost)(x)'),
('innermost', 'assert False')])
('innermost', 'assert False')],
filter_mode=filter_mode)
def test_nested_jit_and_grad(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_nested_jit_and_grad(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
@jit
def innermost(x):
@ -127,11 +152,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
('<lambda>', 'f = lambda: outermost'),
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + grad(innermost)(x)'),
])
], filter_mode=filter_mode)
def test_lax_cond(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_cond(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(_):
assert False
@ -142,11 +166,11 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.cond(True, err, lambda _: (), ())'),
('err', 'assert False')])
('err', 'assert False')],
filter_mode=filter_mode)
def test_lax_switch(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_switch(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(_):
assert False
@ -158,11 +182,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.switch(1, branches, ())'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_lax_scan(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_scan(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(*_):
assert False
@ -173,11 +196,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.scan(err, (), (), 3)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_lax_fori_loop(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_fori_loop(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(*_):
assert False
@ -188,11 +210,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.fori_loop(0, 3, err, ())'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_lax_while_loop(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_while_loop(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(*_):
assert False
@ -204,11 +225,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.while_loop(pred, err, ())'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_lax_map(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_map(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(_):
assert False
@ -220,11 +240,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.map(err, xs)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_lax_custom_root(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_custom_root(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(*_):
assert False
@ -242,17 +261,16 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f1, [
('f1', 'return lax.custom_root(g, 0., err, solve)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
check_filtered_stack_trace(self, AssertionError, f2, [
('f2', 'return lax.custom_root(g, 0., solve, err)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
check_filtered_stack_trace(self, AssertionError, f3, [
('f3', 'return lax.custom_root(err, 0., solve, solve)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_lax_custom_linear_solve(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_custom_linear_solve(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(*_):
assert False
@ -269,14 +287,13 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f1, [
('f1', 'return lax.custom_linear_solve(err, b, solve)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
check_filtered_stack_trace(self, AssertionError, f2, [
('f2', 'return lax.custom_linear_solve(matvec, b, err)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_lax_associative_scan(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_lax_associative_scan(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
def err(*_):
assert False
@ -288,11 +305,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.associative_scan(err, xs)'),
('err', 'assert False')])
('err', 'assert False')], filter_mode=filter_mode)
def test_cause_chain(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def test_cause_chain(self, filter_mode):
skip_if_unsupported_filter_mode(filter_mode)
@jit
def inner(x):
@ -308,7 +324,7 @@ class FilteredTracebackTest(jtu.JaxTestCase):
check_filtered_stack_trace(self, TypeError, f, [
('<lambda>', 'f = lambda: outer'),
('outer', 'raise TypeError')])
('outer', 'raise TypeError')], filter_mode=filter_mode)
e = get_exception(TypeError, f)
self.assertIsInstance(e.__cause__, traceback_util.UnfilteredStackTrace)
self.assertIsInstance(e.__cause__.__cause__, ValueError)