Improve type error when an object dtype is passed to an operator like +.

Fixes #856.
This commit is contained in:
Peter Hawkins 2021-05-10 11:52:12 -04:00
parent fa9ca33e60
commit e6fb6e0881
2 changed files with 14 additions and 4 deletions

View File

@ -232,7 +232,8 @@ _jax_types = [
np.dtype('float64'),
np.dtype('complex64'),
np.dtype('complex128'),
] + _weak_types # type: ignore[operator]
]
_jax_dtype_set = set(_jax_types) | {float0}
def _jax_type(dtype, weak_type):
"""Return the jax type for a dtype and weak type."""
@ -247,7 +248,8 @@ def _type_promotion_lattice():
Return the type promotion lattice in the form of a DAG.
This DAG maps each type to its immediately higher type on the lattice.
"""
b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8, i_, f_, c_ = _jax_types
b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8 = _jax_types
i_, f_, c_ = _weak_types
return {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
@ -275,7 +277,7 @@ def _least_upper_bound(*nodes):
"""Compute the least upper bound of a set of nodes.
Args:
nodes: sequence of entries from _jax_types
nodes: sequence of entries from _jax_types + _weak_types
Returns:
the _jax_type representing the least upper bound of the input nodes
on the promotion lattice.
@ -337,7 +339,11 @@ def is_python_scalar(x):
def dtype(x):
if type(x) in python_scalar_dtypes:
return python_scalar_dtypes[type(x)]
return np.result_type(x)
dt = np.result_type(x)
if dt not in _jax_dtype_set:
raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
return dt
def _lattice_result_type(*args):
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))

View File

@ -1812,6 +1812,10 @@ class APITest(jtu.JaxTestCase):
check_warning(lambda: jnp.arange(1.0).astype("int64"),
lambda: jnp.arange(1.0).astype(int))
def test_error_for_invalid_dtype(self):
with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"):
lax.add(jnp.array(7), np.array("hello"))
def test_vmap_preserves_docstr(self):
def superfun(a):
"""Does things with stuff."""