Add tests of dtypes.finfo properties

This commit is contained in:
Jake VanderPlas 2023-01-09 16:43:10 -08:00
parent 10847a9372
commit 31fd81f2d5
2 changed files with 53 additions and 0 deletions

View File

@ -186,6 +186,8 @@ class _Bfloat16MachArLike:
def __init__(self):
smallest_normal = float.fromhex("0x1p-126")
self.smallest_normal = bfloat16(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-133")
self.smallest_subnormal = bfloat16(smallest_subnormal)
class finfo(np.finfo):
@ -215,11 +217,15 @@ class finfo(np.finfo):
obj.nexp = 8
obj.nmant = 7
obj.iexp = obj.nexp
obj.maxexp = 128
obj.precision = 2
obj.resolution = bfloat16(resolution)
obj._machar = _Bfloat16MachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = bfloat16(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)

View File

@ -294,6 +294,53 @@ class DtypesTest(jtu.JaxTestCase):
self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64)
self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128)
@parameterized.named_parameters(
{"testcase_name": f"_{dtype}", "dtype": dtype}
for dtype in float_dtypes)
def testFInfo(self, dtype):
# Check that finfo attributes are self-consistent & reflect observed behavior.
dtype = np.dtype(dtype)
if dtype == np.float64 and not config.x64_enabled:
self.skipTest("x64 not enabled")
info = dtypes.finfo(dtype)
def make_val(val):
return jnp.array(val, dtype=dtype)
def assertRepresentable(val):
self.assertEqual(make_val(val).item(), val)
@jtu.ignore_warning(category=RuntimeWarning, message="overflow")
def assertInfinite(val):
self.assertEqual(make_val(val), make_val(jnp.inf))
def assertZero(val):
self.assertEqual(make_val(val), make_val(0))
self.assertEqual(jnp.array(0, dtype).dtype, dtype)
self.assertIs(info.dtype, dtype)
self.assertEqual(info.bits, jnp.array(0, dtype).itemsize * 8)
self.assertEqual(info.nmant + info.nexp + 1, info.bits)
assertRepresentable(info.tiny)
assertRepresentable(info.smallest_subnormal)
assertRepresentable(info.max)
assertRepresentable(2.0 ** (info.maxexp - 1))
if dtype != np.float64: # avoid Python float overflows
assertInfinite(info.max * 2)
assertInfinite(2. ** info.maxexp)
assertZero(info.smallest_subnormal * 0.5)
# Identities according to the documentation:
self.assertAllClose(info.resolution, make_val(10 ** -info.precision))
self.assertEqual(info.tiny, info.smallest_normal)
self.assertEqual(info.epsneg, make_val(2 ** info.negep))
self.assertEqual(info.eps, make_val(2 ** info.machep))
self.assertEqual(info.iexp, info.nexp)
class TestPromotionTables(jtu.JaxTestCase):