mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix test_util for new float8 type
This commit is contained in:
parent
0ec9f3c2df
commit
8d165193be
@ -85,6 +85,11 @@ default_gradient_tolerance = {
|
||||
np.dtype(np.complex128): 1e-5,
|
||||
}
|
||||
|
||||
# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
|
||||
if _dtypes.float8_e4m3b11fnuz is not None:
|
||||
_default_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1
|
||||
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1
|
||||
|
||||
def is_python_scalar(val):
|
||||
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
|
||||
|
||||
@ -93,7 +98,9 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
|
||||
np.testing.assert_array_equal(a, b, err_msg=err_msg)
|
||||
return
|
||||
custom_dtypes = [_dtypes.float8_e4m3fn, _dtypes.float8_e5m2, _dtypes.bfloat16]
|
||||
custom_dtypes = [_dtypes.bfloat16]
|
||||
# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
|
||||
if _dtypes.float8_e4m3b11fnuz is not None:
|
||||
custom_dtypes.insert(0, _dtypes.float8_e4m3b11fnuz)
|
||||
a = a.astype(np.float32) if a.dtype in custom_dtypes else a
|
||||
b = b.astype(np.float32) if b.dtype in custom_dtypes else b
|
||||
kw = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user