mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Check for unsupported dtypes and issue a helpful error. (#2885)
This commit is contained in:
parent
52c69e88c5
commit
0557248fbd
@ -5173,8 +5173,13 @@ def _abstractify(x):
|
||||
return raise_to_shaped(core.get_aval(x))
|
||||
|
||||
|
||||
|
||||
def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
if dtype is not None and onp.dtype(dtype) != dtypes.canonicalize_dtype(dtype):
|
||||
onp_dtype = onp.dtype(dtype)
|
||||
if onp_dtype.kind not in "biufc" and onp_dtype.type != dtypes.bfloat16:
|
||||
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
|
||||
raise TypeError(msg)
|
||||
if dtype is not None and onp_dtype != dtypes.canonicalize_dtype(dtype):
|
||||
msg = ("Explicitly requested dtype {} {} is not available, "
|
||||
"and will be truncated to dtype {}. To enable more dtypes, set the "
|
||||
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
|
||||
|
@ -1728,6 +1728,11 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
def testArrayUnsupportedDtypeError(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"JAX only supports number and bool dtypes.*"):
|
||||
jnp.array(3, [('a','<i4'),('b','<i4')])
|
||||
|
||||
def testIssue121(self):
|
||||
assert not onp.isscalar(jnp.array(3))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user