mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add tests of dtypes.finfo properties
This commit is contained in:
parent
10847a9372
commit
31fd81f2d5
@ -186,6 +186,8 @@ class _Bfloat16MachArLike:
|
||||
def __init__(self):
|
||||
smallest_normal = float.fromhex("0x1p-126")
|
||||
self.smallest_normal = bfloat16(smallest_normal)
|
||||
smallest_subnormal = float.fromhex("0x1p-133")
|
||||
self.smallest_subnormal = bfloat16(smallest_subnormal)
|
||||
|
||||
|
||||
class finfo(np.finfo):
|
||||
@ -215,11 +217,15 @@ class finfo(np.finfo):
|
||||
obj.nexp = 8
|
||||
obj.nmant = 7
|
||||
obj.iexp = obj.nexp
|
||||
obj.maxexp = 128
|
||||
obj.precision = 2
|
||||
obj.resolution = bfloat16(resolution)
|
||||
obj._machar = _Bfloat16MachArLike()
|
||||
if not hasattr(obj, "tiny"):
|
||||
obj.tiny = bfloat16(tiny)
|
||||
if not hasattr(obj, "smallest_normal"):
|
||||
obj.smallest_normal = obj._machar.smallest_normal
|
||||
obj.smallest_subnormal = obj._machar.smallest_subnormal
|
||||
|
||||
obj._str_tiny = float_to_str(tiny)
|
||||
obj._str_smallest_normal = float_to_str(tiny)
|
||||
|
@ -294,6 +294,53 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64)
|
||||
self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{dtype}", "dtype": dtype}
|
||||
for dtype in float_dtypes)
|
||||
def testFInfo(self, dtype):
|
||||
# Check that finfo attributes are self-consistent & reflect observed behavior.
|
||||
dtype = np.dtype(dtype)
|
||||
|
||||
if dtype == np.float64 and not config.x64_enabled:
|
||||
self.skipTest("x64 not enabled")
|
||||
info = dtypes.finfo(dtype)
|
||||
|
||||
def make_val(val):
|
||||
return jnp.array(val, dtype=dtype)
|
||||
|
||||
def assertRepresentable(val):
|
||||
self.assertEqual(make_val(val).item(), val)
|
||||
|
||||
@jtu.ignore_warning(category=RuntimeWarning, message="overflow")
|
||||
def assertInfinite(val):
|
||||
self.assertEqual(make_val(val), make_val(jnp.inf))
|
||||
|
||||
def assertZero(val):
|
||||
self.assertEqual(make_val(val), make_val(0))
|
||||
|
||||
self.assertEqual(jnp.array(0, dtype).dtype, dtype)
|
||||
self.assertIs(info.dtype, dtype)
|
||||
|
||||
self.assertEqual(info.bits, jnp.array(0, dtype).itemsize * 8)
|
||||
self.assertEqual(info.nmant + info.nexp + 1, info.bits)
|
||||
|
||||
assertRepresentable(info.tiny)
|
||||
assertRepresentable(info.smallest_subnormal)
|
||||
assertRepresentable(info.max)
|
||||
assertRepresentable(2.0 ** (info.maxexp - 1))
|
||||
|
||||
if dtype != np.float64: # avoid Python float overflows
|
||||
assertInfinite(info.max * 2)
|
||||
assertInfinite(2. ** info.maxexp)
|
||||
assertZero(info.smallest_subnormal * 0.5)
|
||||
|
||||
# Identities according to the documentation:
|
||||
self.assertAllClose(info.resolution, make_val(10 ** -info.precision))
|
||||
self.assertEqual(info.tiny, info.smallest_normal)
|
||||
self.assertEqual(info.epsneg, make_val(2 ** info.negep))
|
||||
self.assertEqual(info.eps, make_val(2 ** info.machep))
|
||||
self.assertEqual(info.iexp, info.nexp)
|
||||
|
||||
|
||||
class TestPromotionTables(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user