2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2020 The JAX Authors.
|
2020-08-14 13:22:20 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import re
|
2023-08-03 10:20:29 -07:00
|
|
|
import sys
|
2020-08-14 13:22:20 -07:00
|
|
|
import traceback
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
2021-03-16 09:10:10 -07:00
|
|
|
from absl.testing import parameterized
|
2020-08-14 13:22:20 -07:00
|
|
|
|
2021-03-16 09:10:10 -07:00
|
|
|
import jax
|
2020-08-14 13:22:20 -07:00
|
|
|
import jax.numpy as jnp
|
2023-02-14 23:00:40 -08:00
|
|
|
from jax import grad, jit, vmap, lax
|
2023-10-12 13:15:22 +01:00
|
|
|
from jax._src import config
|
2023-02-14 23:00:40 -08:00
|
|
|
from jax._src import core
|
2021-05-03 07:48:18 -07:00
|
|
|
from jax._src import source_info_util
|
2023-10-12 13:15:22 +01:00
|
|
|
from jax._src import test_util as jtu
|
2020-11-04 09:01:18 -08:00
|
|
|
from jax._src import traceback_util
|
2020-08-14 13:22:20 -07:00
|
|
|
|
|
|
|
|
|
|
|
# Parse command-line flags through absl so that JAX config options can be
# overridden when this module is run (inferred from the helper's name; see
# jax._src.config.parse_flags_with_absl for the exact behavior).
config.parse_flags_with_absl()
|
|
|
|
|
|
|
|
|
|
|
|
def get_exception(etype, f):
  """Call ``f()`` and return the ``etype`` exception it raises.

  Args:
    etype: exception type (or tuple of types) that ``f`` is expected to raise.
    f: zero-argument callable.

  Returns:
    The caught exception instance.

  Raises:
    AssertionError: if ``f()`` returns normally. This is raised explicitly
      rather than via ``assert`` so the check still fires under ``python -O``.
    Any exception raised by ``f`` that does not match ``etype`` propagates
    unchanged.
  """
  try:
    f()
  except etype as e:
    return e
  raise AssertionError(f"{f!r} did not raise {etype!r}")
|
|
|
|
|
2021-10-04 17:54:18 -07:00
|
|
|
def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
                               filter_mode="remove_frames"):
  """Assert that ``f`` raises ``etype`` with a correctly filtered traceback.

  ``f`` is run twice under ``config.traceback_filtering(filter_mode)``:
  once through ``test.assertRaises`` and once to capture the exception. The
  exception's ``__cause__`` (or, on Python 3.11+ in "quiet_remove_frames"
  mode, its ``__notes__``) is checked against what the mode is expected to
  attach. If ``frame_patterns`` is given, the innermost frames of the
  traceback are then compared against it.

  Args:
    test: the running ``TestCase``, whose assertion methods are used.
    etype: exception type ``f`` must raise.
    f: zero-argument callable expected to raise ``etype``.
    frame_patterns: sequence of ``(function_name, source_line_prefix)``
      pairs, ordered outermost to innermost, describing the innermost frames
      of the filtered traceback. Both strings are matched literally (they
      are ``re.escape``d). Each matched frame must belong to this file, and
      ``source_line_prefix`` must match the start of that frame's displayed
      source line.
    filter_mode: traceback filtering mode to enable while running ``f``.
  """
  with config.traceback_filtering(filter_mode):
    test.assertRaises(etype, f)
    e = get_exception(etype, f)
  c = e.__cause__
  if filter_mode == "quiet_remove_frames":
    if sys.version_info >= (3, 11):
      # Guard with getattr: a missing __notes__ should fail the assertion,
      # not raise AttributeError.
      notes = getattr(e, "__notes__", ())
      test.assertTrue(any("For simplicity" in x for x in notes),
                      msg=f"expected a 'For simplicity' note, got {notes!r}")
    else:
      test.assertIsInstance(c, jax.errors.SimplifiedTraceback)
  elif filter_mode == "remove_frames":
    test.assertIsInstance(c, traceback_util.UnfilteredStackTrace)
  else:
    test.assertFalse(isinstance(c, traceback_util.UnfilteredStackTrace))

  if not frame_patterns:
    return

  frames = []
  for frame, lineno in traceback.walk_tb(e.__traceback__):
    # In "tracebackhide" mode, frames that define a local named
    # __tracebackhide__ are treated as hidden, so drop them before comparing.
    if filter_mode == "tracebackhide":
      if "__tracebackhide__" in frame.f_locals:
        continue
    frames.append((frame, lineno))
  c_tb = traceback.format_list(traceback.StackSummary.extract(frames))

  # zip() below stops at the shorter input. Without this check, expected
  # frames with no counterpart in the traceback would go unnoticed instead
  # of failing the test.
  test.assertGreaterEqual(
      len(c_tb), len(frame_patterns),
      msg="filtered traceback has fewer frames than expected:\n"
          + "".join(c_tb))

  file = re.escape(__file__)
  # Compare innermost-first so frame_patterns lines up with the tail of c_tb.
  for (fname_pat, line_pat), frame_fmt in zip(
      reversed(frame_patterns), reversed(c_tb)):
    full_pat = (
        f' File "{file}", line ' r'[0-9]+'
        f', in {re.escape(fname_pat)}' r'\n\s*' f'{re.escape(line_pat)}')
    test.assertRegex(frame_fmt, full_pat)
|
|
|
|
|
|
|
|
|
2022-02-15 02:42:30 -08:00
|
|
|
@jtu.with_config(jax_traceback_filtering='auto')  # JaxTestCase defaults to off.
@parameterized.named_parameters(
    {"testcase_name": f"_{f}", "filter_mode": f}
    for f in ("tracebackhide", "remove_frames", "quiet_remove_frames"))
class FilteredTracebackTest(jtu.JaxTestCase):
  """Checks which frames survive JAX's traceback filtering.

  Every test method is generated once per ``filter_mode`` listed in the
  ``named_parameters`` decorator above. Most tests raise from a user function
  that is reached through a JAX API (jit, vmap, grad, lax control flow,
  custom derivatives, ...). They then use ``check_filtered_stack_trace`` to
  assert that the innermost frames of the filtered traceback are the given
  user-code frames, written as ``(function name, source line prefix)`` pairs.

  NOTE: the expected source lines are matched against this file's literal
  text. Keep the nested ``def`` names and the quoted lines in sync when
  editing these tests.
  """

  def test_nested_jit(self, filter_mode):
    # The failure happens three jit levels deep; every user frame on the way
    # down is expected to remain in the filtered traceback.
    @jit
    def innermost(x):
      assert False
    @jit
    def inbetween(x):
      return 1 + innermost(x)
    @jit
    def outermost(x):
      return 2 + inbetween(x)

    f = lambda: outermost(jnp.array([1, 2]))

    check_filtered_stack_trace(self, AssertionError, f, [
        ('<lambda>', 'f = lambda: outermost'),
        ('outermost', 'return 2 + inbetween(x)'),
        ('inbetween', 'return 1 + innermost(x)'),
        ('innermost', 'assert False')],
        filter_mode=filter_mode)

  def test_nested_jit_and_vmap(self, filter_mode):
    @jit
    def innermost(x):
      assert False
    @jit
    def inbetween(x):
      return 1 + vmap(innermost)(x)
    @jit
    def outermost(x):
      return 2 + inbetween(x)

    f = lambda: outermost(jnp.array([1, 2]))

    check_filtered_stack_trace(self, AssertionError, f, [
        ('<lambda>', 'f = lambda: outermost'),
        ('outermost', 'return 2 + inbetween(x)'),
        ('inbetween', 'return 1 + vmap(innermost)(x)'),
        ('innermost', 'assert False')],
        filter_mode=filter_mode)

  def test_nested_jit_and_grad(self, filter_mode):
    @jit
    def innermost(x):
      assert False
    @jit
    def inbetween(x):
      return 1 + grad(innermost)(x)
    @jit
    def outermost(x):
      return 2 + inbetween(x)

    f = lambda: outermost(jnp.array([1, 2]))

    # A TypeError is expected and no `innermost` frame is listed.
    # NOTE(review): presumably grad rejects the integer-valued, non-scalar
    # input before `innermost` is traced; confirm.
    check_filtered_stack_trace(self, TypeError, f, [
        ('<lambda>', 'f = lambda: outermost'),
        ('outermost', 'return 2 + inbetween(x)'),
        ('inbetween', 'return 1 + grad(innermost)(x)'),
    ], filter_mode=filter_mode)

  def test_lax_cond(self, filter_mode):
    def err(_):
      assert False
      return ()

    def f():
      return lax.cond(True, err, lambda _: (), ())

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.cond(True, err, lambda _: (), ())'),
        ('err', 'assert False')],
        filter_mode=filter_mode)

  def test_lax_switch(self, filter_mode):
    def err(_):
      assert False
      return ()

    def f():
      branches = [lambda _: (), err, lambda _: ()]
      return lax.switch(1, branches, ())

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.switch(1, branches, ())'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_lax_scan(self, filter_mode):
    def err(*_):
      assert False
      return ()

    def f():
      return lax.scan(err, (), (), 3)

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.scan(err, (), (), 3)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_lax_fori_loop(self, filter_mode):
    def err(*_):
      assert False
      return ()

    def f():
      return lax.fori_loop(0, 3, err, ())

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.fori_loop(0, 3, err, ())'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_lax_while_loop(self, filter_mode):
    def err(*_):
      assert False
      return ()

    def f():
      # pred is always False, yet err is still expected to raise.
      # NOTE(review): presumably because the body is traced when the loop is
      # built, independent of the predicate's value.
      pred = lambda _: False
      return lax.while_loop(pred, err, ())

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.while_loop(pred, err, ())'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_lax_map(self, filter_mode):
    def err(_):
      assert False
      return ()

    def f():
      xs = jnp.ones(3)
      return lax.map(err, xs)

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.map(err, xs)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_lax_custom_root(self, filter_mode):
    def err(*_):
      assert False
      return ()

    def g(x): return (x - 1.) ** 2.
    def solve(*_): return 1.

    # `err` is placed in each callable argument of lax.custom_root in turn.
    def f1():
      return lax.custom_root(g, 0., err, solve)
    def f2():
      return lax.custom_root(g, 0., solve, err)
    def f3():
      return lax.custom_root(err, 0., solve, solve)

    check_filtered_stack_trace(self, AssertionError, f1, [
        ('f1', 'return lax.custom_root(g, 0., err, solve)'),
        ('err', 'assert False')], filter_mode=filter_mode)
    check_filtered_stack_trace(self, AssertionError, f2, [
        ('f2', 'return lax.custom_root(g, 0., solve, err)'),
        ('err', 'assert False')], filter_mode=filter_mode)
    check_filtered_stack_trace(self, AssertionError, f3, [
        ('f3', 'return lax.custom_root(err, 0., solve, solve)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_lax_custom_linear_solve(self, filter_mode):
    def err(*_):
      assert False
      return ()

    matvec = lambda v: v
    solve = lambda mv, b: 1.
    b = 1.

    # `err` is used first as the matvec argument, then as the solve argument.
    def f1():
      return lax.custom_linear_solve(err, b, solve)
    def f2():
      return lax.custom_linear_solve(matvec, b, err)

    check_filtered_stack_trace(self, AssertionError, f1, [
        ('f1', 'return lax.custom_linear_solve(err, b, solve)'),
        ('err', 'assert False')], filter_mode=filter_mode)
    check_filtered_stack_trace(self, AssertionError, f2, [
        ('f2', 'return lax.custom_linear_solve(matvec, b, err)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_lax_associative_scan(self, filter_mode):
    def err(*_):
      assert False
      return ()

    def f():
      xs = jnp.arange(4.)
      return lax.associative_scan(err, xs)

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return lax.associative_scan(err, xs)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_custom_jvp(self, filter_mode):
    def err(*args):
      assert False
      return args

    @jax.custom_jvp
    def f(x):
      return err(x)

    @f.defjvp
    def f_jvp(x, tx):
      x = err(x)
      return x, tx

    # Calling f directly fails in the primal function; calling jax.jvp is
    # expected to fail in the custom JVP rule instead.
    check_filtered_stack_trace(self, AssertionError, lambda: f(1.), [
        ('f', 'return err(x)'),
        ('err', 'assert False')], filter_mode=filter_mode)

    check_filtered_stack_trace(self, AssertionError, lambda: jax.jvp(f, [1.], [1.]), [
        ('f_jvp', 'x = err(x)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_custom_vjp(self, filter_mode):
    def err(*args):
      assert False
      return args[0]

    @jax.custom_vjp
    def f(x):
      return err(x)

    def fwd(x):
      return x, ()

    def fwd_err(x):
      x = err(x)
      return x, ()

    def bwd(_, g):
      return (g,)

    def bwd_err(_, g):
      g = err(g)
      return (g,)

    # First install a failing forward rule, then a failing backward rule, and
    # check that each failure is reported at the user's rule.
    f.defvjp(fwd_err, bwd)

    check_filtered_stack_trace(self, AssertionError, lambda: f(1.), [
        ('f', 'return err(x)'),
        ('err', 'assert False')], filter_mode=filter_mode)

    check_filtered_stack_trace(self, AssertionError, lambda: jax.grad(f)(1.), [
        ('fwd_err', 'x = err(x)'),
        ('err', 'assert False')], filter_mode=filter_mode)

    f.defvjp(fwd, bwd_err)

    check_filtered_stack_trace(self, AssertionError, lambda: jax.grad(f)(1.), [
        ('bwd_err', 'g = err(g)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_jvp(self, filter_mode):
    def err(_):
      assert False
      return ()

    def f():
      p = (1.,)
      t = (0.,)
      return jax.jvp(err, p, t)

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return jax.jvp(err, p, t)'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_vjp(self, filter_mode):
    def err(_):
      assert False
      return ()

    def f():
      x = 1.
      return jax.vjp(err, x)[0]

    check_filtered_stack_trace(self, AssertionError, f, [
        ('f', 'return jax.vjp(err, x)[0]'),
        ('err', 'assert False')], filter_mode=filter_mode)

  def test_debug_nans(self, filter_mode):
    @jax.jit
    def f(x):
      return 0. / x

    # NOTE(review): presumably a warm-up call with a finite result, so the
    # failing call below reaches an already-compiled function; confirm.
    f(2.)
    def g():
      return f(0.)

    # With debug_nans enabled, computing 0. / 0. is expected to surface as a
    # ZeroDivisionError reported at the user's `return 0. / x` line.
    with jax.debug_nans(True):
      check_filtered_stack_trace(self, ZeroDivisionError, g, [
          ('g', 'return f(0.)'),
          ('f', 'return 0. / x')], filter_mode=filter_mode)

  def test_cause_chain(self, filter_mode):
    @jit
    def inner(x):
      raise ValueError('inner')

    @jit
    def outer(x):
      try:
        inner(x)
      except ValueError as e:
        raise TypeError('outer') from e

    f = lambda: outer(1.)

    check_filtered_stack_trace(self, TypeError, f, [
        ('<lambda>', 'f = lambda: outer'),
        ('outer', 'raise TypeError')], filter_mode=filter_mode)
    # Outside check_filtered_stack_trace, the class-level 'auto' filtering
    # setting applies. The user's original ValueError must stay reachable
    # through the cause chain.
    e = get_exception(TypeError, f)  # Uses the default JAX_TRACEBACK_FILTERING=auto
    if sys.version_info >= (3, 11):
      notes = getattr(e, "__notes__", ())
      self.assertTrue(any("For simplicity" in x for x in notes),
                      msg=f"expected a 'For simplicity' note, got {notes!r}")
      self.assertIsInstance(e.__cause__, ValueError)
    else:
      self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback)
      self.assertIsInstance(e.__cause__.__cause__, ValueError)

  def test_null_traceback(self, filter_mode):
    """An error raised inside JAX before any user code runs.

    When an error is raised under a JAX API boundary but before entering any
    user code, every frame in between is JAX-internal. Here, jit rejects an
    argument that is not a valid JAX type. The filtered traceback should then
    end at the user's call site, with no JAX-internal frame after it.
    """
    class TestA: pass
    def f(a): return a + 1

    def err():
      a = TestA()
      return jit(f)(a)

    check_filtered_stack_trace(self, TypeError, err, [
        ('err', 'return jit(f)(a)')], filter_mode=filter_mode)
|
|
|
|
|
2020-11-18 10:08:18 -05:00
|
|
|
|
2022-02-15 02:42:30 -08:00
|
|
|
@jtu.with_config(jax_traceback_filtering='auto')  # JaxTestCase defaults to off.
class UserContextTracebackTest(jtu.JaxTestCase):
  """Checks that errors raised under transformations keep the user context."""

  def test_grad_norm(self):
    # Differentiating the norm at the all-zeros input is expected to produce
    # an invalid (NaN) value, which debug_nans reports as FloatingPointError.
    with self.assertRaises(FloatingPointError) as cm:
      with jax.debug_nans(True):
        jax.grad(jnp.linalg.norm)(jnp.zeros((3, 3), jnp.float32))
    e = cm.exception
    self.assertIn("invalid value", str(e))
    # On Python >= 3.11 the pre-transformation stack trace is the direct
    # cause. On older versions it sits one link further down the chain.
    if sys.version_info >= (3, 11):
      cause = e.__cause__
    else:
      cause = e.__cause__.__cause__
    self.assertIsInstance(
        cause, source_info_util.JaxStackTraceBeforeTransformation)
|
2021-05-03 07:48:18 -07:00
|
|
|
|
|
|
|
|
2021-03-16 09:10:10 -07:00
|
|
|
class CustomErrorsTest(jtu.JaxTestCase):
  """Checks that ``jax.errors`` ``*Error`` messages link to their docs entry."""

  # Every public name in jax.errors ending in 'Error' is tested, except the
  # three listed below.
  # NOTE(review): presumably these cannot be built from a single tracer or do
  # not carry the docs URL; confirm before editing the list.
  @jtu.sample_product(
      errorclass=[
          name for name in dir(jax.errors)
          if name.endswith('Error')
          and name not in ('JaxIndexError', 'JAXTypeError', 'JaxRuntimeError')
      ],
  )
  def testErrorsURL(self, errorclass):
    # Minimal Tracer subclass: just enough to pass to the error constructors.
    class StubTracer(core.Tracer):
      aval = None

    error_type = getattr(jax.errors, errorclass)
    message = str(error_type(StubTracer(None)))
    self.assertIn(
        f'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.{errorclass}',
        message)
|
|
|
|
|
|
|
|
|
2020-08-14 13:22:20 -07:00
|
|
|
if __name__ == '__main__':
  # Run this module's tests through absltest, using JAX's test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)
|