mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add more informative error when static argument is passed to non-static JIT parameter
This commit is contained in:
parent
8196c8bf36
commit
a44e129ae7
@ -645,6 +645,18 @@ def _infer_params_impl(
|
||||
"An overflow was encountered while parsing an argument to a jitted "
|
||||
f"computation, whose {arg_path}."
|
||||
) from e
|
||||
except TypeError as e:
|
||||
arg_description = (f"path {dbg.arg_names[i]}" if dbg
|
||||
else f"flattened argument number {i}")
|
||||
raise TypeError(
|
||||
f"Error interpreting argument to {fun} as an abstract array."
|
||||
f" The problematic value is of type {type(a)} and was passed to"
|
||||
f" the function at {arg_description}.\n"
|
||||
"This typically means that a jit-wrapped function was called with a non-array"
|
||||
" argument, and this argument was not marked as static using the"
|
||||
" static_argnums or static_argnames parameters of jax.jit."
|
||||
) from e
|
||||
|
||||
in_type = in_avals = tuple(avals)
|
||||
else:
|
||||
in_type = in_avals
|
||||
|
@ -733,13 +733,12 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, r".* 'foo' of type <.*'str'> is not a valid JAX type"):
|
||||
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
|
||||
"value is of type .* and was passed to the function at path x.")
|
||||
with self.assertRaisesRegex(TypeError, err_str):
|
||||
jit(f)("foo")
|
||||
|
||||
# Jax type objects aren't valid data arguments.
|
||||
err_str = "JAX scalar type .*int32.* cannot be interpreted as a JAX array."
|
||||
|
||||
with self.assertRaisesRegex(TypeError, err_str):
|
||||
jit(f)(jnp.int32)
|
||||
|
||||
@ -1576,13 +1575,14 @@ class APITest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
self.assertRaisesRegex(
|
||||
TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
|
||||
lambda: grad(f)("foo"))
|
||||
with self.assertRaisesRegex(TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"):
|
||||
grad(f)("foo")
|
||||
|
||||
self.assertRaisesRegex(
|
||||
TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
|
||||
lambda: jit(f)("foo"))
|
||||
|
||||
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
|
||||
"value is of type .* and was passed to the function at path x.")
|
||||
with self.assertRaisesRegex(TypeError, err_str):
|
||||
jit(f)("foo")
|
||||
|
||||
def test_grad_tuple_output(self):
|
||||
jtu.check_raises(lambda: grad(lambda x: (x,x))(1.0), TypeError,
|
||||
@ -2959,8 +2959,10 @@ class APITest(jtu.JaxTestCase):
|
||||
lambda: jnp.arange(1.0).astype(int))
|
||||
|
||||
def test_error_for_invalid_dtype(self):
|
||||
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
|
||||
r"value is of type .* and was passed to the function at path args\[1\].")
|
||||
with jax.enable_checks(False):
|
||||
with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"):
|
||||
with self.assertRaisesRegex(TypeError, err_str):
|
||||
lax.add(jnp.array(7), np.array("hello"))
|
||||
with jax.enable_checks(True):
|
||||
with self.assertRaises(AssertionError):
|
||||
|
@ -2844,9 +2844,10 @@ class LaxTest(jtu.JaxTestCase):
|
||||
(np.int32(1), np.int16(2))))
|
||||
|
||||
def test_primitive_jaxtype_error(self):
|
||||
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
|
||||
r"value is of type .* and was passed to the function at path args\[1\].")
|
||||
with jax.enable_checks(False):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Argument .* of type .* is not a valid JAX type"):
|
||||
with self.assertRaisesRegex(TypeError, err_str):
|
||||
lax.add(1, 'hi')
|
||||
|
||||
def test_reduction_with_repeated_axes_error(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user