mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add a --jax_traceback_filtering flag to control the traceback filtering mode.
Add a new traceback filtering mode that uses __tracebackhide__, and use it in IPython.
This commit is contained in:
parent
46cc654537
commit
2882286b50
@ -12,6 +12,10 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
|
||||
* New features:
|
||||
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
|
||||
* A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
|
||||
tracebacks.
|
||||
* A new traceback filtering mode using `__tracebackhide__` is now enabled by
|
||||
default in sufficiently recent versions of IPython.
|
||||
|
||||
* Breaking changes:
|
||||
|
||||
|
@ -250,7 +250,10 @@ class Config:
|
||||
See docstring for ``define_bool_state``.
|
||||
"""
|
||||
name = name.lower()
|
||||
self.DEFINE_enum(name, os.getenv(name.upper(), default),
|
||||
default = os.getenv(name.upper(), default)
|
||||
if default is not None and default not in enum_values:
|
||||
raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
|
||||
self.DEFINE_enum(name, default,
|
||||
enum_values=enum_values, help=help,
|
||||
update_hook=update_global_hook)
|
||||
self._contextmanager_flags.add(name)
|
||||
@ -517,3 +520,17 @@ default_matmul_precision = config.define_enum_state(
|
||||
update_global_jit_state(default_matmul_precision=val),
|
||||
update_thread_local_hook=lambda val: \
|
||||
update_thread_local_jit_state(default_matmul_precision=val))
|
||||
|
||||
traceback_filtering = config.define_enum_state(
|
||||
name = 'jax_traceback_filtering',
|
||||
enum_values=["off", "tracebackhide", "remove_frames", "auto"],
|
||||
default="auto",
|
||||
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
|
||||
"Valid values are:\n"
|
||||
" * \"off\": disables traceback filtering.\n"
|
||||
" * \"auto\": use \"tracebackhide\" if running under a sufficiently "
|
||||
"new IPython, or \"remove_frames\" otherwise.\n"
|
||||
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to "
|
||||
" hidden stack frames, which some traceback printers support.\n"
|
||||
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
|
||||
" the unfiltered traceback as a __cause__ of the exception.\n")
|
||||
|
@ -16,7 +16,9 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax.lib import xla_extension
|
||||
from jax._src import util
|
||||
|
||||
@ -56,6 +58,11 @@ def include_frame(f):
|
||||
def ignore_known_hidden_frame(f):
|
||||
return 'importlib._bootstrap' in f.f_code.co_filename
|
||||
|
||||
def add_tracebackhide_to_hidden_frames(tb):
|
||||
for f, lineno in traceback.walk_tb(tb):
|
||||
if not include_frame(f):
|
||||
f.f_locals["__tracebackhide__"] = True
|
||||
|
||||
def filter_traceback(tb):
|
||||
out = None
|
||||
# Scan the traceback and collect relevant frames.
|
||||
@ -111,6 +118,38 @@ make_traceback = (types.TracebackType if sys.version_info >= (3, 7) else
|
||||
def filtered_tracebacks_supported():
|
||||
return make_traceback is not None
|
||||
|
||||
def running_under_ipython():
|
||||
"""Returns true if we appear to be in an IPython session."""
|
||||
try:
|
||||
get_ipython() # type: ignore
|
||||
return True
|
||||
except NameError:
|
||||
return False
|
||||
|
||||
def python_supports_tracebackhide():
|
||||
"""Returns true we can add __tracebackhide__ to frames."""
|
||||
# TODO(phawkins): remove this test after droppping Python 3.6 support.
|
||||
return sys.version_info[:2] >= (3, 7)
|
||||
|
||||
def ipython_supports_tracebackhide():
|
||||
"""Returns true if the IPython version supports __tracebackhide__."""
|
||||
import IPython # type: ignore
|
||||
return IPython.version_info[:2] >= (7, 17)
|
||||
|
||||
def filtering_mode():
|
||||
mode = jax.config.jax_traceback_filtering
|
||||
if mode is None or mode == "auto":
|
||||
if (running_under_ipython() and ipython_supports_tracebackhide() and
|
||||
python_supports_tracebackhide()):
|
||||
mode = "tracebackhide"
|
||||
else:
|
||||
mode = "remove_frames"
|
||||
if mode == "tracebackhide" and not python_supports_tracebackhide():
|
||||
warnings.warn("--jax_traceback_filtering=tracebackhide requires Python 3.7 "
|
||||
"or newer.")
|
||||
mode = "remove_frames"
|
||||
return mode
|
||||
|
||||
def api_boundary(fun):
|
||||
'''Wraps ``fun`` to form a boundary for filtering exception tracebacks.
|
||||
|
||||
@ -139,39 +178,47 @@ def api_boundary(fun):
|
||||
|
||||
@util.wraps(fun)
|
||||
def reraise_with_filtered_traceback(*args, **kwargs):
|
||||
__tracebackhide__ = True
|
||||
try:
|
||||
return fun(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if not is_under_reraiser(e):
|
||||
filtered_tb, unfiltered = None, None
|
||||
try:
|
||||
filtered_tb = filter_traceback(e.__traceback__)
|
||||
if filtered_tb is None:
|
||||
raise
|
||||
msg = format_exception_only(e)
|
||||
msg = f'{msg}\n\n{_jax_message_append}'
|
||||
unfiltered = UnfilteredStackTrace(msg)
|
||||
unfiltered.with_traceback(add_call_stack_frames(e.__traceback__))
|
||||
unfiltered.__context__ = e.__context__
|
||||
unfiltered.__cause__ = e.__cause__
|
||||
unfiltered.__suppress_context__ = e.__suppress_context__
|
||||
e.__context__ = None
|
||||
e.__cause__ = unfiltered
|
||||
# There seems to be no way to alter the currently raised exception's
|
||||
# traceback, except via the C API. The currently raised exception
|
||||
# is part of the interpreter's thread state: value `e` is a copy.
|
||||
if hasattr(xla_extension, 'replace_thread_exc_traceback'):
|
||||
xla_extension.replace_thread_exc_traceback(filtered_tb)
|
||||
raise
|
||||
else:
|
||||
# TODO(phawkins): remove this case when jaxlib 0.1.66 is the
|
||||
# minimum.
|
||||
|
||||
# Fallback case for older jaxlibs; includes the current frame.
|
||||
raise e.with_traceback(filtered_tb)
|
||||
finally:
|
||||
del filtered_tb
|
||||
del unfiltered
|
||||
else:
|
||||
mode = filtering_mode()
|
||||
if is_under_reraiser(e) or mode == "off":
|
||||
raise
|
||||
if mode == "tracebackhide":
|
||||
add_tracebackhide_to_hidden_frames(e.__traceback__)
|
||||
raise
|
||||
assert mode == "remove_frames", mode
|
||||
|
||||
filtered_tb, unfiltered, mode = None, None, None
|
||||
try:
|
||||
filtered_tb = filter_traceback(e.__traceback__)
|
||||
if filtered_tb is None:
|
||||
raise
|
||||
msg = format_exception_only(e)
|
||||
msg = f'{msg}\n\n{_jax_message_append}'
|
||||
unfiltered = UnfilteredStackTrace(msg)
|
||||
unfiltered.with_traceback(add_call_stack_frames(e.__traceback__))
|
||||
unfiltered.__context__ = e.__context__
|
||||
unfiltered.__cause__ = e.__cause__
|
||||
unfiltered.__suppress_context__ = e.__suppress_context__
|
||||
e.__context__ = None
|
||||
e.__cause__ = unfiltered
|
||||
|
||||
# There seems to be no way to alter the currently raised exception's
|
||||
# traceback, except via the C API. The currently raised exception
|
||||
# is part of the interpreter's thread state: value `e` is a copy.
|
||||
if hasattr(xla_extension, 'replace_thread_exc_traceback'):
|
||||
xla_extension.replace_thread_exc_traceback(filtered_tb)
|
||||
raise
|
||||
else:
|
||||
# TODO(phawkins): remove this case when jaxlib 0.1.66 is the
|
||||
# minimum.
|
||||
|
||||
# Fallback case for older jaxlibs; includes the current frame.
|
||||
raise e.with_traceback(filtered_tb)
|
||||
finally:
|
||||
del filtered_tb
|
||||
del unfiltered
|
||||
del mode
|
||||
return reraise_with_filtered_traceback
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
@ -40,16 +41,31 @@ def get_exception(etype, f):
|
||||
return e
|
||||
assert False
|
||||
|
||||
def check_filtered_stack_trace(test, etype, f, frame_patterns=[]):
|
||||
test.assertRaises(etype, f)
|
||||
e = get_exception(etype, f)
|
||||
def check_filtered_stack_trace(test, etype, f, frame_patterns=[],
|
||||
filter_mode="remove_frames"):
|
||||
with jax._src.config.traceback_filtering(filter_mode):
|
||||
test.assertRaises(etype, f)
|
||||
e = get_exception(etype, f)
|
||||
c = e.__cause__
|
||||
test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
|
||||
c_tb = traceback.format_tb(e.__traceback__)
|
||||
# TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum.
|
||||
if not hasattr(xla_extension, "replace_thread_exc_traceback"):
|
||||
c_tb = [t for t in c_tb if "reraise_with_filtered_traceback" not in t]
|
||||
if filter_mode == "remove_frames":
|
||||
test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
|
||||
else:
|
||||
test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace))
|
||||
|
||||
if frame_patterns:
|
||||
frames = []
|
||||
for frame, lineno in traceback.walk_tb(e.__traceback__):
|
||||
if filter_mode == "tracebackhide":
|
||||
if "__tracebackhide__" in frame.f_locals.keys():
|
||||
continue
|
||||
elif filter_mode == "remove_frames":
|
||||
# TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum.
|
||||
if (not hasattr(xla_extension, "replace_thread_exc_traceback") and
|
||||
frame.f_code.co_name == "reraise_with_filtered_traceback"):
|
||||
continue
|
||||
frames.append((frame, lineno))
|
||||
|
||||
c_tb = traceback.format_list(traceback.StackSummary.extract(frames))
|
||||
for (fname_pat, line_pat), frame_fmt in zip(
|
||||
reversed(frame_patterns), reversed(c_tb)):
|
||||
file = re.escape(__file__)
|
||||
@ -60,12 +76,21 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[]):
|
||||
f', in {fname_pat}' r'\n\s*' f'{line_pat}')
|
||||
test.assertRegex(frame_fmt, full_pat)
|
||||
|
||||
def skip_if_unsupported_filter_mode(filter_mode):
|
||||
if (filter_mode == "remove_frames" and
|
||||
not traceback_util.filtered_tracebacks_supported()):
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
elif filter_mode == "tracebackhide" and sys.version_info[:2] < (3, 7):
|
||||
raise unittest.SkipTest('Tracebackhide requires Python 3.7 or newer')
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{f}", "filter_mode": f}
|
||||
for f in ("tracebackhide", "remove_frames"))
|
||||
class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
def test_nested_jit(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_nested_jit(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
@jit
|
||||
def innermost(x):
|
||||
@ -83,11 +108,11 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
('<lambda>', 'f = lambda: outermost'),
|
||||
('outermost', 'return 2 + inbetween(x)'),
|
||||
('inbetween', 'return 1 + innermost(x)'),
|
||||
('innermost', 'assert False')])
|
||||
('innermost', 'assert False')],
|
||||
filter_mode=filter_mode)
|
||||
|
||||
def test_nested_jit_and_vmap(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_nested_jit_and_vmap(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
@jit
|
||||
def innermost(x):
|
||||
@ -105,11 +130,11 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
('<lambda>', 'f = lambda: outermost'),
|
||||
('outermost', 'return 2 + inbetween(x)'),
|
||||
('inbetween', 'return 1 + vmap(innermost)(x)'),
|
||||
('innermost', 'assert False')])
|
||||
('innermost', 'assert False')],
|
||||
filter_mode=filter_mode)
|
||||
|
||||
def test_nested_jit_and_grad(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_nested_jit_and_grad(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
@jit
|
||||
def innermost(x):
|
||||
@ -127,11 +152,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
('<lambda>', 'f = lambda: outermost'),
|
||||
('outermost', 'return 2 + inbetween(x)'),
|
||||
('inbetween', 'return 1 + grad(innermost)(x)'),
|
||||
])
|
||||
], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_cond(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_cond(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(_):
|
||||
assert False
|
||||
@ -142,11 +166,11 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.cond(True, err, lambda _: (), ())'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')],
|
||||
filter_mode=filter_mode)
|
||||
|
||||
def test_lax_switch(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_switch(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(_):
|
||||
assert False
|
||||
@ -158,11 +182,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.switch(1, branches, ())'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_scan(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_scan(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
@ -173,11 +196,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.scan(err, (), (), 3)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_fori_loop(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_fori_loop(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
@ -188,11 +210,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.fori_loop(0, 3, err, ())'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_while_loop(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_while_loop(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
@ -204,11 +225,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.while_loop(pred, err, ())'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_map(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_map(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(_):
|
||||
assert False
|
||||
@ -220,11 +240,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.map(err, xs)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_custom_root(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_custom_root(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
@ -242,17 +261,16 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f1, [
|
||||
('f1', 'return lax.custom_root(g, 0., err, solve)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
check_filtered_stack_trace(self, AssertionError, f2, [
|
||||
('f2', 'return lax.custom_root(g, 0., solve, err)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
check_filtered_stack_trace(self, AssertionError, f3, [
|
||||
('f3', 'return lax.custom_root(err, 0., solve, solve)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_custom_linear_solve(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_custom_linear_solve(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
@ -269,14 +287,13 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f1, [
|
||||
('f1', 'return lax.custom_linear_solve(err, b, solve)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
check_filtered_stack_trace(self, AssertionError, f2, [
|
||||
('f2', 'return lax.custom_linear_solve(matvec, b, err)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_lax_associative_scan(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_lax_associative_scan(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
@ -288,11 +305,10 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.associative_scan(err, xs)'),
|
||||
('err', 'assert False')])
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_cause_chain(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
def test_cause_chain(self, filter_mode):
|
||||
skip_if_unsupported_filter_mode(filter_mode)
|
||||
|
||||
@jit
|
||||
def inner(x):
|
||||
@ -308,7 +324,7 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
|
||||
check_filtered_stack_trace(self, TypeError, f, [
|
||||
('<lambda>', 'f = lambda: outer'),
|
||||
('outer', 'raise TypeError')])
|
||||
('outer', 'raise TypeError')], filter_mode=filter_mode)
|
||||
e = get_exception(TypeError, f)
|
||||
self.assertIsInstance(e.__cause__, traceback_util.UnfilteredStackTrace)
|
||||
self.assertIsInstance(e.__cause__.__cause__, ValueError)
|
||||
|
Loading…
x
Reference in New Issue
Block a user