Added better error messages. (#2058)

#2057

Added better error messages for when a user accidentally uses a python cast instead of a the `jax.numpy` casting.
This commit is contained in:
Chase Roberts 2020-01-27 15:44:33 -08:00 committed by Skye Wanderman-Milne
parent e0ed5adc75
commit 82d6c6ce51
2 changed files with 25 additions and 11 deletions

View File

@ -24,17 +24,18 @@ from . import dtypes
from . util import prod, partialmethod
def concretization_err_msg(fun):
def concretization_err_msg(fun, context=None):
fname = getattr(fun, "__name__", fun)
msg = ("Abstract value passed to `{}`, which requires a concrete value. "
"The function to be transformed can't be traced at the required level "
"of abstraction. If using `jit`, try using `static_argnums` or "
"applying `jit` to smaller subfunctions instead.")
return msg.format(fname)
if context is None:
context = ("The function to be transformed can't be traced at the required level "
"of abstraction. If using `jit`, try using `static_argnums` or "
"applying `jit` to smaller subfunctions instead.")
msg = "Abstract value passed to `{}`, which requires a concrete value. {}"
return msg.format(fname, context)
def concretization_function_error(fun):
def concretization_function_error(fun, context=None):
def error(self, *args):
raise TypeError(concretization_err_msg(fun))
raise TypeError(concretization_err_msg(fun, context))
return error
@ -64,9 +65,12 @@ class UnshapedArray(core.AbstractValue):
", weak_type=True" if self.weak_type else "")
_bool = _nonzero = concretization_function_error(bool)
_float = concretization_function_error(float)
_int = concretization_function_error(int)
_complex = concretization_function_error(complex)
_float = concretization_function_error(
float, "Try using `value.astype(float)` instead.")
_int = concretization_function_error(
int, "Try using `value.astype(int)` instead.")
_complex = concretization_function_error(
complex, "Try using `value.astype(complex)` instead.")
_hex = concretization_function_error(hex)
_oct = concretization_function_error(oct)

View File

@ -221,6 +221,16 @@ class APITest(jtu.JaxTestCase):
self.assertRaisesRegex(
TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).",
lambda: grad(f)(onp.zeros(3), onp.zeros(4)))
def test_abstract_error_message(self):
for castfun in [float, complex, int]:
def f(x):
return castfun(x)
self.assertRaisesRegex(
TypeError,
"Try using `value.astype\({}\)` instead".format(castfun.__name__),
lambda: jit(f)(1.0))
def test_switch_value_jit(self):
def f(x):