Second attempt at fixing warnings from jax.dtypes.issubdtype.

PiperOrigin-RevId: 532902714
This commit is contained in:
Peter Hawkins 2023-05-17 14:09:42 -07:00 committed by jax authors
parent 2aa2282ea1
commit 389564551b
2 changed files with 25 additions and 4 deletions

View File

@ -221,6 +221,25 @@ def _issubclass(a: Any, b: Any) -> bool:
except TypeError:
return False
_type_classes = {
np.generic,
np.number,
np.flexible,
np.character,
np.integer,
np.signedinteger,
np.unsignedinteger,
np.inexact,
np.floating,
np.complexfloating,
}
def _is_typeclass(a: Any) -> bool:
try:
return a in _type_classes
except TypeError:
return False
def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
"""Returns True if first argument is a typecode lower/equal in type hierarchy.
@ -229,13 +248,14 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
"""
if is_opaque_dtype(a):
return a == b
a = a if _issubclass(a, np.generic) else np.dtype(a)
b = b if _issubclass(b, np.generic) else np.dtype(b)
if a in _custom_float_dtypes:
# Canonicalizes all concrete types to np.dtype instances
a = a if _is_typeclass(a) else np.dtype(a)
b = b if _is_typeclass(b) else np.dtype(b)
if isinstance(a, np.dtype) and a in _custom_float_dtypes:
# Avoid implicitly casting list elements below to a dtype.
if isinstance(b, np.dtype):
return a == b
return b in [a, np.floating, np.inexact, np.number]
return b in [np.floating, np.inexact, np.number]
return np.issubdtype(a, b)
can_cast = np.can_cast

View File

@ -275,6 +275,7 @@ class DtypesTest(jtu.JaxTestCase):
self.assertTrue(dtypes.issubdtype(dt, np.inexact))
self.assertTrue(dtypes.issubdtype(dt, np.number))
self.assertFalse(dtypes.issubdtype(dt, np.float64))
self.assertFalse(dtypes.issubdtype(np.generic, dt))
def testArrayCasts(self):
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]: