Add more informative error when static argument is passed to non-static JIT parameter

This commit is contained in:
Jake VanderPlas 2024-09-24 05:22:18 -07:00
parent 8196c8bf36
commit a44e129ae7
3 changed files with 28 additions and 13 deletions

View File

@ -645,6 +645,18 @@ def _infer_params_impl(
"An overflow was encountered while parsing an argument to a jitted "
f"computation, whose {arg_path}."
) from e
except TypeError as e:
arg_description = (f"path {dbg.arg_names[i]}" if dbg
else f"flattened argument number {i}")
raise TypeError(
f"Error interpreting argument to {fun} as an abstract array."
f" The problematic value is of type {type(a)} and was passed to"
f" the function at {arg_description}.\n"
"This typically means that a jit-wrapped function was called with a non-array"
" argument, and this argument was not marked as static using the"
" static_argnums or static_argnames parameters of jax.jit."
) from e
in_type = in_avals = tuple(avals)
else:
in_type = in_avals

View File

@ -733,13 +733,12 @@ class JitTest(jtu.BufferDonationTestCase):
def f(x):
return x
with self.assertRaisesRegex(
TypeError, r".* 'foo' of type <.*'str'> is not a valid JAX type"):
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
"value is of type .* and was passed to the function at path x.")
with self.assertRaisesRegex(TypeError, err_str):
jit(f)("foo")
# Jax type objects aren't valid data arguments.
err_str = "JAX scalar type .*int32.* cannot be interpreted as a JAX array."
with self.assertRaisesRegex(TypeError, err_str):
jit(f)(jnp.int32)
@ -1576,13 +1575,14 @@ class APITest(jtu.JaxTestCase):
def f(x):
return x
self.assertRaisesRegex(
TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
lambda: grad(f)("foo"))
with self.assertRaisesRegex(TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"):
grad(f)("foo")
self.assertRaisesRegex(
TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
lambda: jit(f)("foo"))
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
"value is of type .* and was passed to the function at path x.")
with self.assertRaisesRegex(TypeError, err_str):
jit(f)("foo")
def test_grad_tuple_output(self):
jtu.check_raises(lambda: grad(lambda x: (x,x))(1.0), TypeError,
@ -2959,8 +2959,10 @@ class APITest(jtu.JaxTestCase):
lambda: jnp.arange(1.0).astype(int))
def test_error_for_invalid_dtype(self):
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
r"value is of type .* and was passed to the function at path args\[1\].")
with jax.enable_checks(False):
with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"):
with self.assertRaisesRegex(TypeError, err_str):
lax.add(jnp.array(7), np.array("hello"))
with jax.enable_checks(True):
with self.assertRaises(AssertionError):

View File

@ -2844,9 +2844,10 @@ class LaxTest(jtu.JaxTestCase):
(np.int32(1), np.int16(2))))
def test_primitive_jaxtype_error(self):
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
r"value is of type .* and was passed to the function at path args\[1\].")
with jax.enable_checks(False):
with self.assertRaisesRegex(
TypeError, "Argument .* of type .* is not a valid JAX type"):
with self.assertRaisesRegex(TypeError, err_str):
lax.add(1, 'hi')
def test_reduction_with_repeated_axes_error(self):