mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
BUG: avoid warning when specifying dtype=complex in X32 mode
This commit is contained in:
parent
b51ee3752e
commit
0a531ac76f
@ -4535,7 +4535,7 @@ def _abstractify(x):
|
||||
|
||||
def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
# Avoid using `dtype in [...]` because of numpy dtype equality overloading.
|
||||
if isinstance(dtype, type) and dtype in {bool, int, float, complex}:
|
||||
if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}:
|
||||
return
|
||||
np_dtype = np.dtype(dtype)
|
||||
if np_dtype.kind not in "biufc" and np_dtype.type != dtypes.bfloat16:
|
||||
|
@ -2461,6 +2461,13 @@ class APITest(jtu.JaxTestCase):
|
||||
for x, y in zip(xs, ys):
|
||||
self.assertAllClose(x, y)
|
||||
|
||||
def test_dtype_from_builtin_types(self):
|
||||
for dtype in [bool, int, float, complex]:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
x = jnp.array(0, dtype=dtype)
|
||||
self.assertEmpty(caught_warnings)
|
||||
assert x.dtype == dtypes.canonicalize_dtype(dtype)
|
||||
|
||||
def test_dtype_warning(self):
|
||||
# cf. issue #1230
|
||||
if config.x64_enabled:
|
||||
|
Loading…
x
Reference in New Issue
Block a user