BUG: avoid warning when specifying dtype=complex in X32 mode

This commit is contained in:
Jake VanderPlas 2022-06-15 15:02:45 -07:00
parent b51ee3752e
commit 0a531ac76f
2 changed files with 8 additions and 1 deletions

View File

@ -4535,7 +4535,7 @@ def _abstractify(x):
def _check_user_dtype_supported(dtype, fun_name=None):
# Avoid using `dtype in [...]` because of numpy dtype equality overloading.
if isinstance(dtype, type) and dtype in {bool, int, float, complex}:
if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}:
return
np_dtype = np.dtype(dtype)
if np_dtype.kind not in "biufc" and np_dtype.type != dtypes.bfloat16:

View File

@ -2461,6 +2461,13 @@ class APITest(jtu.JaxTestCase):
for x, y in zip(xs, ys):
self.assertAllClose(x, y)
def test_dtype_from_builtin_types(self):
for dtype in [bool, int, float, complex]:
with warnings.catch_warnings(record=True) as caught_warnings:
x = jnp.array(0, dtype=dtype)
self.assertEmpty(caught_warnings)
assert x.dtype == dtypes.canonicalize_dtype(dtype)
def test_dtype_warning(self):
# cf. issue #1230
if config.x64_enabled: