mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added better error messages. (#2058)
#2057 Added better error messages for when a user accidentally uses a python cast instead of a the `jax.numpy` casting.
This commit is contained in:
parent
e0ed5adc75
commit
82d6c6ce51
@ -24,17 +24,18 @@ from . import dtypes
|
||||
from . util import prod, partialmethod
|
||||
|
||||
|
||||
def concretization_err_msg(fun):
|
||||
def concretization_err_msg(fun, context=None):
|
||||
fname = getattr(fun, "__name__", fun)
|
||||
msg = ("Abstract value passed to `{}`, which requires a concrete value. "
|
||||
"The function to be transformed can't be traced at the required level "
|
||||
"of abstraction. If using `jit`, try using `static_argnums` or "
|
||||
"applying `jit` to smaller subfunctions instead.")
|
||||
return msg.format(fname)
|
||||
if context is None:
|
||||
context = ("The function to be transformed can't be traced at the required level "
|
||||
"of abstraction. If using `jit`, try using `static_argnums` or "
|
||||
"applying `jit` to smaller subfunctions instead.")
|
||||
msg = "Abstract value passed to `{}`, which requires a concrete value. {}"
|
||||
return msg.format(fname, context)
|
||||
|
||||
def concretization_function_error(fun):
|
||||
def concretization_function_error(fun, context=None):
|
||||
def error(self, *args):
|
||||
raise TypeError(concretization_err_msg(fun))
|
||||
raise TypeError(concretization_err_msg(fun, context))
|
||||
return error
|
||||
|
||||
|
||||
@ -64,9 +65,12 @@ class UnshapedArray(core.AbstractValue):
|
||||
", weak_type=True" if self.weak_type else "")
|
||||
|
||||
_bool = _nonzero = concretization_function_error(bool)
|
||||
_float = concretization_function_error(float)
|
||||
_int = concretization_function_error(int)
|
||||
_complex = concretization_function_error(complex)
|
||||
_float = concretization_function_error(
|
||||
float, "Try using `value.astype(float)` instead.")
|
||||
_int = concretization_function_error(
|
||||
int, "Try using `value.astype(int)` instead.")
|
||||
_complex = concretization_function_error(
|
||||
complex, "Try using `value.astype(complex)` instead.")
|
||||
_hex = concretization_function_error(hex)
|
||||
_oct = concretization_function_error(oct)
|
||||
|
||||
|
@ -221,6 +221,16 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertRaisesRegex(
|
||||
TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).",
|
||||
lambda: grad(f)(onp.zeros(3), onp.zeros(4)))
|
||||
|
||||
def test_abstract_error_message(self):
|
||||
for castfun in [float, complex, int]:
|
||||
def f(x):
|
||||
return castfun(x)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Try using `value.astype\({}\)` instead".format(castfun.__name__),
|
||||
lambda: jit(f)(1.0))
|
||||
|
||||
def test_switch_value_jit(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user