mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Second attempt at fixing warnings from jax.dtypes.issubdtype.
PiperOrigin-RevId: 532902714
This commit is contained in:
parent
2aa2282ea1
commit
389564551b
@ -221,6 +221,25 @@ def _issubclass(a: Any, b: Any) -> bool:
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
_type_classes = {
|
||||
np.generic,
|
||||
np.number,
|
||||
np.flexible,
|
||||
np.character,
|
||||
np.integer,
|
||||
np.signedinteger,
|
||||
np.unsignedinteger,
|
||||
np.inexact,
|
||||
np.floating,
|
||||
np.complexfloating,
|
||||
}
|
||||
|
||||
def _is_typeclass(a: Any) -> bool:
|
||||
try:
|
||||
return a in _type_classes
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
|
||||
"""Returns True if first argument is a typecode lower/equal in type hierarchy.
|
||||
|
||||
@ -229,13 +248,14 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
|
||||
"""
|
||||
if is_opaque_dtype(a):
|
||||
return a == b
|
||||
a = a if _issubclass(a, np.generic) else np.dtype(a)
|
||||
b = b if _issubclass(b, np.generic) else np.dtype(b)
|
||||
if a in _custom_float_dtypes:
|
||||
# Canonicalizes all concrete types to np.dtype instances
|
||||
a = a if _is_typeclass(a) else np.dtype(a)
|
||||
b = b if _is_typeclass(b) else np.dtype(b)
|
||||
if isinstance(a, np.dtype) and a in _custom_float_dtypes:
|
||||
# Avoid implicitly casting list elements below to a dtype.
|
||||
if isinstance(b, np.dtype):
|
||||
return a == b
|
||||
return b in [a, np.floating, np.inexact, np.number]
|
||||
return b in [np.floating, np.inexact, np.number]
|
||||
return np.issubdtype(a, b)
|
||||
|
||||
can_cast = np.can_cast
|
||||
|
@ -275,6 +275,7 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
self.assertTrue(dtypes.issubdtype(dt, np.inexact))
|
||||
self.assertTrue(dtypes.issubdtype(dt, np.number))
|
||||
self.assertFalse(dtypes.issubdtype(dt, np.float64))
|
||||
self.assertFalse(dtypes.issubdtype(np.generic, dt))
|
||||
|
||||
def testArrayCasts(self):
|
||||
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user