Check for unsupported dtypes and issue a helpful error. (#2885)

This commit is contained in:
Peter Hawkins 2020-04-29 14:14:49 -04:00 committed by GitHub
parent 52c69e88c5
commit 0557248fbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 1 deletions

View File

@ -5173,8 +5173,13 @@ def _abstractify(x):
return raise_to_shaped(core.get_aval(x))
def _check_user_dtype_supported(dtype, fun_name=None):
if dtype is not None and onp.dtype(dtype) != dtypes.canonicalize_dtype(dtype):
onp_dtype = onp.dtype(dtype)
if onp_dtype.kind not in "biufc" and onp_dtype.type != dtypes.bfloat16:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
raise TypeError(msg)
if dtype is not None and onp_dtype != dtypes.canonicalize_dtype(dtype):
msg = ("Explicitly requested dtype {} {} is not available, "
"and will be truncated to dtype {}. To enable more dtypes, set the "
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "

View File

@ -1728,6 +1728,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
def testArrayUnsupportedDtypeError(self):
with self.assertRaisesRegex(TypeError,
"JAX only supports number and bool dtypes.*"):
jnp.array(3, [('a','<i4'),('b','<i4')])
def testIssue121(self):
assert not onp.isscalar(jnp.array(3))