Fix test_util for new float8 type

This commit is contained in:
Jake VanderPlas 2023-06-08 00:30:45 -07:00
parent 0ec9f3c2df
commit 8d165193be

View File

@ -85,6 +85,11 @@ default_gradient_tolerance = {
np.dtype(np.complex128): 1e-5,
}
# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
if _dtypes.float8_e4m3b11fnuz is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1
def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@ -93,7 +98,9 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
np.testing.assert_array_equal(a, b, err_msg=err_msg)
return
custom_dtypes = [_dtypes.float8_e4m3fn, _dtypes.float8_e5m2, _dtypes.bfloat16]
custom_dtypes = [_dtypes.bfloat16]
# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
if _dtypes.float8_e4m3b11fnuz is not None:
custom_dtypes.insert(0, _dtypes.float8_e4m3b11fnuz)
a = a.astype(np.float32) if a.dtype in custom_dtypes else a
b = b.astype(np.float32) if b.dtype in custom_dtypes else b
kw = {}