mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve type error when an object dtype is passed to an operator like +.
Fixes #856.
This commit is contained in:
parent
fa9ca33e60
commit
e6fb6e0881
@ -232,7 +232,8 @@ _jax_types = [
|
||||
np.dtype('float64'),
|
||||
np.dtype('complex64'),
|
||||
np.dtype('complex128'),
|
||||
] + _weak_types # type: ignore[operator]
|
||||
]
|
||||
_jax_dtype_set = set(_jax_types) | {float0}
|
||||
|
||||
def _jax_type(dtype, weak_type):
|
||||
"""Return the jax type for a dtype and weak type."""
|
||||
@ -247,7 +248,8 @@ def _type_promotion_lattice():
|
||||
Return the type promotion lattice in the form of a DAG.
|
||||
This DAG maps each type to its immediately higher type on the lattice.
|
||||
"""
|
||||
b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8, i_, f_, c_ = _jax_types
|
||||
b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8 = _jax_types
|
||||
i_, f_, c_ = _weak_types
|
||||
return {
|
||||
b1: [i_],
|
||||
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
|
||||
@ -275,7 +277,7 @@ def _least_upper_bound(*nodes):
|
||||
"""Compute the least upper bound of a set of nodes.
|
||||
|
||||
Args:
|
||||
nodes: sequence of entries from _jax_types
|
||||
nodes: sequence of entries from _jax_types + _weak_types
|
||||
Returns:
|
||||
the _jax_type representing the least upper bound of the input nodes
|
||||
on the promotion lattice.
|
||||
@ -337,7 +339,11 @@ def is_python_scalar(x):
|
||||
def dtype(x):
|
||||
if type(x) in python_scalar_dtypes:
|
||||
return python_scalar_dtypes[type(x)]
|
||||
return np.result_type(x)
|
||||
dt = np.result_type(x)
|
||||
if dt not in _jax_dtype_set:
|
||||
raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
|
||||
"type. Only arrays of numeric types are supported by JAX.")
|
||||
return dt
|
||||
|
||||
def _lattice_result_type(*args):
|
||||
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
|
||||
|
@ -1812,6 +1812,10 @@ class APITest(jtu.JaxTestCase):
|
||||
check_warning(lambda: jnp.arange(1.0).astype("int64"),
|
||||
lambda: jnp.arange(1.0).astype(int))
|
||||
|
||||
def test_error_for_invalid_dtype(self):
|
||||
with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"):
|
||||
lax.add(jnp.array(7), np.array("hello"))
|
||||
|
||||
def test_vmap_preserves_docstr(self):
|
||||
def superfun(a):
|
||||
"""Does things with stuff."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user