mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Evaluate the correctness of JAX complex functions using mpmath as a reference
This commit is contained in:
parent
2848cda34c
commit
fdb5015909
@ -4,6 +4,7 @@ cloudpickle
|
|||||||
colorama>=0.4.4
|
colorama>=0.4.4
|
||||||
flatbuffers
|
flatbuffers
|
||||||
hypothesis
|
hypothesis
|
||||||
|
mpmath>=1.3
|
||||||
numpy>=1.22
|
numpy>=1.22
|
||||||
pillow>=9.1.0
|
pillow>=9.1.0
|
||||||
portpicker
|
portpicker
|
||||||
|
@ -1467,11 +1467,11 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
|
|||||||
>>> print(complex_plane_sample(np.complex64, 0, 3))
|
>>> print(complex_plane_sample(np.complex64, 0, 3))
|
||||||
[[-inf -infj 0. -infj inf -infj]
|
[[-inf -infj 0. -infj inf -infj]
|
||||||
[-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
|
[-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
|
||||||
[-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j]
|
[-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j]
|
||||||
[-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
|
[-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
|
||||||
[-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
|
[-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
|
||||||
[-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
|
[-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
|
||||||
[-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j]
|
[-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j]
|
||||||
[-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
|
[-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
|
||||||
[-inf +infj 0. +infj inf +infj]]
|
[-inf +infj 0. +infj inf +infj]]
|
||||||
|
|
||||||
@ -1481,16 +1481,18 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
|
|||||||
finfo = np.finfo(dtype)
|
finfo = np.finfo(dtype)
|
||||||
|
|
||||||
def make_axis_points(size):
|
def make_axis_points(size):
|
||||||
logmin = np.log10(abs(finfo.min))
|
prec_dps_ratio = 3.3219280948873626
|
||||||
logtiny = np.log10(finfo.tiny)
|
logmin = logmax = finfo.maxexp / prec_dps_ratio
|
||||||
logmax = np.log10(finfo.max)
|
logtiny = finfo.minexp / prec_dps_ratio
|
||||||
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
|
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
# Silence RuntimeWarning: overflow encountered in cast
|
# Silence RuntimeWarning: overflow encountered in cast
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
|
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
|
||||||
axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype)
|
half_line = -half_neg_line[::-1]
|
||||||
|
axis_points[-size - 1:-1] = half_line
|
||||||
|
axis_points[1:size + 1] = half_neg_line
|
||||||
|
|
||||||
if size > 1:
|
if size > 1:
|
||||||
axis_points[1] = finfo.min
|
axis_points[1] = finfo.min
|
||||||
@ -1512,3 +1514,302 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
|
|||||||
imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)
|
imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)
|
||||||
|
|
||||||
return real_part + imag_part
|
return real_part + imag_part
|
||||||
|
|
||||||
|
|
||||||
|
class vectorize_with_mpmath(np.vectorize):
|
||||||
|
"""Same as numpy.vectorize but using mpmath backend for function evaluation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
map_float_to_complex = dict(float16='complex32', float32='complex64', float64='complex128', float128='complex256', longdouble='clongdouble')
|
||||||
|
map_complex_to_float = {v: k for k, v in map_float_to_complex.items()}
|
||||||
|
|
||||||
|
float_prec = dict(
|
||||||
|
# float16=11,
|
||||||
|
float32=24,
|
||||||
|
float64=53,
|
||||||
|
# float128=113,
|
||||||
|
# longdouble=113
|
||||||
|
)
|
||||||
|
|
||||||
|
float_minexp = dict(
|
||||||
|
float16=-14,
|
||||||
|
float32=-126,
|
||||||
|
float64=-1022,
|
||||||
|
float128=-16382
|
||||||
|
)
|
||||||
|
|
||||||
|
float_maxexp = dict(
|
||||||
|
float16=16,
|
||||||
|
float32=128,
|
||||||
|
float64=1024,
|
||||||
|
float128=16384,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
mpmath = kwargs.pop('mpmath', None)
|
||||||
|
if mpmath is None:
|
||||||
|
raise ValueError('vectorize_with_mpmath: no mpmath argument specified')
|
||||||
|
self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0)
|
||||||
|
self.extra_prec = kwargs.pop('extra_prec', 0)
|
||||||
|
self.mpmath = mpmath
|
||||||
|
self.contexts = dict()
|
||||||
|
self.contexts_inv = dict()
|
||||||
|
for fp_format, prec in self.float_prec.items():
|
||||||
|
ctx = self.mpmath.mp.clone()
|
||||||
|
ctx.prec = prec
|
||||||
|
self.contexts[fp_format] = ctx
|
||||||
|
self.contexts_inv[ctx] = fp_format
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_context(self, x):
|
||||||
|
if isinstance(x, (np.ndarray, np.floating, np.complexfloating)):
|
||||||
|
fp_format = str(x.dtype)
|
||||||
|
fp_format = self.map_complex_to_float.get(fp_format, fp_format)
|
||||||
|
return self.contexts[fp_format]
|
||||||
|
raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance')
|
||||||
|
|
||||||
|
def nptomp(self, x):
|
||||||
|
"""Convert numpy array/scalar to an array/instance of mpmath number type.
|
||||||
|
"""
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape)
|
||||||
|
elif isinstance(x, np.floating):
|
||||||
|
mpmath = self.mpmath
|
||||||
|
ctx = self.get_context(x)
|
||||||
|
prec, rounding = ctx._prec_rounding
|
||||||
|
if np.isposinf(x):
|
||||||
|
return ctx.make_mpf(mpmath.libmp.finf)
|
||||||
|
elif np.isneginf(x):
|
||||||
|
return ctx.make_mpf(mpmath.libmp.fninf)
|
||||||
|
elif np.isnan(x):
|
||||||
|
return ctx.make_mpf(mpmath.libmp.fnan)
|
||||||
|
elif np.isfinite(x):
|
||||||
|
mantissa, exponent = np.frexp(x)
|
||||||
|
man = int(np.ldexp(mantissa, prec))
|
||||||
|
exp = int(exponent - prec)
|
||||||
|
r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding))
|
||||||
|
assert ctx.isfinite(r), r._mpf_
|
||||||
|
return r
|
||||||
|
elif isinstance(x, np.complexfloating):
|
||||||
|
re, im = self.nptomp(x.real), self.nptomp(x.imag)
|
||||||
|
return re.context.make_mpc((re._mpf_, im._mpf_))
|
||||||
|
raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type')
|
||||||
|
|
||||||
|
def mptonp(self, x):
|
||||||
|
"""Convert mpmath instance to numpy array/scalar type.
|
||||||
|
"""
|
||||||
|
if isinstance(x, np.ndarray) and x.dtype.kind == 'O':
|
||||||
|
x_flat = x.flatten()
|
||||||
|
item = x_flat[0]
|
||||||
|
ctx = item.context
|
||||||
|
fp_format = self.contexts_inv[ctx]
|
||||||
|
if isinstance(item, ctx.mpc):
|
||||||
|
dtype = getattr(np, self.map_float_to_complex[fp_format])
|
||||||
|
elif isinstance(item, ctx.mpf):
|
||||||
|
dtype = getattr(np, fp_format)
|
||||||
|
else:
|
||||||
|
dtype = None
|
||||||
|
if dtype is not None:
|
||||||
|
return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape)
|
||||||
|
elif isinstance(x, self.mpmath.ctx_mp.mpnumeric):
|
||||||
|
ctx = x.context
|
||||||
|
if isinstance(x, ctx.mpc):
|
||||||
|
fp_format = self.contexts_inv[ctx]
|
||||||
|
dtype = getattr(np, self.map_float_to_complex[fp_format])
|
||||||
|
r = dtype().reshape(1).view(getattr(np, fp_format))
|
||||||
|
r[0] = self.mptonp(x.real)
|
||||||
|
r[1] = self.mptonp(x.imag)
|
||||||
|
return r.view(dtype)[0]
|
||||||
|
elif isinstance(x, ctx.mpf):
|
||||||
|
fp_format = self.contexts_inv[ctx]
|
||||||
|
dtype = getattr(np, fp_format)
|
||||||
|
if ctx.isfinite(x):
|
||||||
|
sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding)
|
||||||
|
assert bc >= 0, (sign, man, exp, bc, x._mpf_)
|
||||||
|
if exp + bc < self.float_minexp[fp_format]:
|
||||||
|
return -ctx.zero if sign else ctx.zero
|
||||||
|
if exp + bc > self.float_maxexp[fp_format]:
|
||||||
|
return ctx.ninf if sign else ctx.inf
|
||||||
|
man = dtype(-man if sign else man)
|
||||||
|
r = np.ldexp(man, exp)
|
||||||
|
assert np.isfinite(r), (x, r, x._mpf_, man)
|
||||||
|
return r
|
||||||
|
elif ctx.isnan(x):
|
||||||
|
return dtype(np.nan)
|
||||||
|
elif ctx.isinf(x):
|
||||||
|
return dtype(-np.inf if x._mpf_[0] else np.inf)
|
||||||
|
raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type')
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
mp_args = []
|
||||||
|
context = None
|
||||||
|
for a in args:
|
||||||
|
if isinstance(a, (np.ndarray, np.floating, np.complexfloating)):
|
||||||
|
mp_args.append(self.nptomp(a))
|
||||||
|
if context is None:
|
||||||
|
context = self.get_context(a)
|
||||||
|
else:
|
||||||
|
assert context is self.get_context(a)
|
||||||
|
else:
|
||||||
|
mp_args.append(a)
|
||||||
|
|
||||||
|
extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec
|
||||||
|
with context.extraprec(extra_prec):
|
||||||
|
result = super().__call__(*mp_args, **kwargs)
|
||||||
|
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
lst = []
|
||||||
|
for r in result:
|
||||||
|
if ((isinstance(r, np.ndarray) and r.dtype.kind == 'O')
|
||||||
|
or isinstance(r, self.mpmath.ctx_mp.mpnumeric)):
|
||||||
|
r = self.mptonp(r)
|
||||||
|
lst.append(r)
|
||||||
|
return tuple(lst)
|
||||||
|
|
||||||
|
if ((isinstance(result, np.ndarray) and result.dtype.kind == 'O')
|
||||||
|
or isinstance(result, self.mpmath.ctx_mp.mpnumeric)):
|
||||||
|
return self.mptonp(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class numpy_with_mpmath:
|
||||||
|
"""Namespace of universal functions on numpy arrays that use mpmath
|
||||||
|
backend for evaluation and return numpy arrays as outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_provides = [
|
||||||
|
'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2',
|
||||||
|
'log', 'log1p', 'log10', 'log2',
|
||||||
|
'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
|
||||||
|
'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
|
||||||
|
'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc',
|
||||||
|
'normalize',
|
||||||
|
]
|
||||||
|
|
||||||
|
_mp_names = dict(
|
||||||
|
abs='absmin', absolute='absmin',
|
||||||
|
log='ln',
|
||||||
|
arcsin='asin', arccos='acos', arctan='atan',
|
||||||
|
arcsinh='asinh', arccosh='acosh', arctanh='atanh',
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0):
|
||||||
|
self.mpmath = mpmath
|
||||||
|
|
||||||
|
for name in self._provides:
|
||||||
|
mp_name = self._mp_names.get(name, name)
|
||||||
|
|
||||||
|
if hasattr(self, name):
|
||||||
|
op = getattr(self, name)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def op(x, mp_name=mp_name):
|
||||||
|
return getattr(x.context, mp_name)(x)
|
||||||
|
|
||||||
|
setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec))
|
||||||
|
|
||||||
|
# The following function methods operate on mpmath number instances.
|
||||||
|
# The corresponding function names must be listed in
|
||||||
|
# numpy_with_mpmath._provides list.
|
||||||
|
|
||||||
|
def square(self, x):
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
def positive(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def negative(self, x):
|
||||||
|
return -x
|
||||||
|
|
||||||
|
def sqrt(self, x):
|
||||||
|
ctx = x.context
|
||||||
|
# workaround mpmath bugs:
|
||||||
|
if isinstance(x, ctx.mpc):
|
||||||
|
if ctx.isinf(x.real) and ctx.isinf(x.imag):
|
||||||
|
if x.real > 0: return x
|
||||||
|
ninf = x.real
|
||||||
|
inf = -ninf
|
||||||
|
if x.imag > 0: return ctx.make_mpc((inf._mpf_, inf._mpf_))
|
||||||
|
return ctx.make_mpc((inf._mpf_, inf._mpf_))
|
||||||
|
elif ctx.isfinite(x.real) and ctx.isinf(x.imag):
|
||||||
|
if x.imag > 0:
|
||||||
|
inf = x.imag
|
||||||
|
return ctx.make_mpc((inf._mpf_, inf._mpf_))
|
||||||
|
else:
|
||||||
|
ninf = x.imag
|
||||||
|
inf = -ninf
|
||||||
|
return ctx.make_mpc((inf._mpf_, ninf._mpf_))
|
||||||
|
return ctx.sqrt(x)
|
||||||
|
|
||||||
|
def expm1(self, x):
|
||||||
|
return x.context.expm1(x)
|
||||||
|
|
||||||
|
def log2(self, x):
|
||||||
|
return x.context.ln(x) / x.context.ln2
|
||||||
|
|
||||||
|
def log10(self, x):
|
||||||
|
return x.context.ln(x) / x.context.ln10
|
||||||
|
|
||||||
|
def exp2(self, x):
|
||||||
|
return x.context.exp(x * x.context.ln2)
|
||||||
|
|
||||||
|
def normalize(self, exact, reference, value):
|
||||||
|
"""Normalize reference and value using precision defined by the
|
||||||
|
difference of exact and reference.
|
||||||
|
"""
|
||||||
|
def worker(ctx, s, e, r, v):
|
||||||
|
ss, sm, se, sbc = s._mpf_
|
||||||
|
es, em, ee, ebc = e._mpf_
|
||||||
|
rs, rm, re, rbc = r._mpf_
|
||||||
|
vs, vm, ve, vbc = v._mpf_
|
||||||
|
|
||||||
|
if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)):
|
||||||
|
return r, v
|
||||||
|
|
||||||
|
me = min(se, ee, re, ve)
|
||||||
|
|
||||||
|
# transform mantissa parts to the same exponent base
|
||||||
|
sm_e = sm << (se - me)
|
||||||
|
em_e = em << (ee - me)
|
||||||
|
rm_e = rm << (re - me)
|
||||||
|
vm_e = vm << (ve - me)
|
||||||
|
|
||||||
|
# find matching higher and non-matching lower bits of e and r
|
||||||
|
sm_b = bin(sm_e)[2:] if sm_e else ''
|
||||||
|
em_b = bin(em_e)[2:] if em_e else ''
|
||||||
|
rm_b = bin(rm_e)[2:] if rm_e else ''
|
||||||
|
vm_b = bin(vm_e)[2:] if vm_e else ''
|
||||||
|
|
||||||
|
m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b))
|
||||||
|
em_b = '0' * (m - len(em_b)) + em_b
|
||||||
|
rm_b = '0' * (m - len(rm_b)) + rm_b
|
||||||
|
|
||||||
|
c1 = 0
|
||||||
|
for b0, b1 in zip(em_b, rm_b):
|
||||||
|
if b0 != b1:
|
||||||
|
break
|
||||||
|
c1 += 1
|
||||||
|
c0 = m - c1
|
||||||
|
|
||||||
|
# truncate r and v mantissa
|
||||||
|
rm_m = rm_e >> c0
|
||||||
|
vm_m = vm_e >> c0
|
||||||
|
|
||||||
|
# normalized r and v
|
||||||
|
nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero)
|
||||||
|
nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero)
|
||||||
|
|
||||||
|
return nr, nv
|
||||||
|
|
||||||
|
ctx = exact.context
|
||||||
|
scale = abs(exact)
|
||||||
|
if isinstance(exact, ctx.mpc):
|
||||||
|
rr, rv = worker(ctx, scale, exact.real, reference.real, value.real)
|
||||||
|
ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag)
|
||||||
|
return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_))
|
||||||
|
elif isinstance(exact, ctx.mpf):
|
||||||
|
return worker(ctx, scale, exact, reference, value)
|
||||||
|
else:
|
||||||
|
assert 0 # unreachable
|
||||||
|
@ -3365,87 +3365,133 @@ class CustomElementTypesTest(jtu.JaxTestCase):
|
|||||||
class FunctionAccuracyTest(jtu.JaxTestCase):
|
class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind)
|
dict(testcase_name=f"_{dtype.__name__}", dtype=dtype)
|
||||||
for name, dtype, kind in itertools.product(
|
for dtype in jtu.dtypes.supported([np.float32, np.float64, np.complex64, np.complex128]))
|
||||||
[ 'arccos', 'arccosh', 'arcsin', 'arcsinh',
|
def testMPMathUtils(self, dtype):
|
||||||
'arctan', 'arctanh', 'conjugate', 'cos',
|
try:
|
||||||
'cosh', 'exp', 'exp2', 'expm1', 'log',
|
import mpmath
|
||||||
'log10', 'log1p', 'sin', 'sinh', 'sqrt',
|
except ImportError as msg:
|
||||||
'square', 'tan', 'tanh', 'sinc', 'positive',
|
self.skipTest(f'could not import mpmath: {msg}')
|
||||||
'negative', 'absolute', 'sign'],
|
|
||||||
|
prec = {np.float32: 24, np.float64: 53, np.complex64: 24, np.complex128: 53}[dtype]
|
||||||
|
is_complex = dtype().dtype.kind == 'c'
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
assert isinstance(x, mpmath.ctx_mp.mpnumeric)
|
||||||
|
assert x.context.prec == prec
|
||||||
|
assert isinstance(x, x.context.mpc if is_complex else x.context.mpf)
|
||||||
|
return x
|
||||||
|
|
||||||
|
ufunc = jtu.vectorize_with_mpmath(func, mpmath=mpmath)
|
||||||
|
|
||||||
|
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||||
|
if is_complex:
|
||||||
|
arr = jtu.complex_plane_sample(dtype=dtype, size_re=11)
|
||||||
|
else:
|
||||||
|
cdtype = getattr(np, ufunc.map_float_to_complex[dtype.__name__])
|
||||||
|
arr = jtu.complex_plane_sample(dtype=cdtype, size_re=11, size_im=0)[1:2].real
|
||||||
|
|
||||||
|
arr2 = ufunc.mptonp(ufunc.nptomp(arr))
|
||||||
|
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||||
|
self.assertAllClose(arr, arr2, atol=0, rtol=0)
|
||||||
|
|
||||||
|
arr3 = ufunc(arr)
|
||||||
|
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||||
|
self.assertAllClose(arr, arr3, atol=0, rtol=0)
|
||||||
|
|
||||||
|
if is_complex:
|
||||||
|
# tests scale in normalize
|
||||||
|
v = dtype(1.1071487177940644+1.1102230246251565e-16j)
|
||||||
|
r = dtype(1.1071487177940644+0j)
|
||||||
|
mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1)
|
||||||
|
nr, nv = mnp.normalize(r, r, v)
|
||||||
|
self.assertAllClose(nr, nv)
|
||||||
|
|
||||||
|
_functions_on_complex_plane = [
|
||||||
|
'arccos', 'arccosh', 'arcsin', 'arcsinh',
|
||||||
|
'arctan', 'arctanh', 'conjugate', 'cos',
|
||||||
|
'cosh', 'exp', 'exp2', 'expm1', 'log',
|
||||||
|
'log10', 'log1p', 'sin', 'sinh', 'sqrt',
|
||||||
|
'square', 'tan', 'tanh', 'sinc', 'positive',
|
||||||
|
'negative', 'absolute', 'sign'
|
||||||
|
]
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
|
||||||
|
for name, dtype in itertools.product(
|
||||||
|
_functions_on_complex_plane,
|
||||||
jtu.dtypes.supported([np.complex64, np.complex128]),
|
jtu.dtypes.supported([np.complex64, np.complex128]),
|
||||||
['success', 'failure'],
|
|
||||||
))
|
))
|
||||||
@jtu.skip_on_devices("tpu")
|
@jtu.skip_on_devices("tpu")
|
||||||
def testOnComplexPlane(self, name, dtype, kind):
|
def testSuccessOnComplexPlane(self, name, dtype):
|
||||||
all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero']
|
self._testOnComplexPlaneWorker(name, dtype, 'success')
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
|
||||||
|
for name, dtype in itertools.product(
|
||||||
|
_functions_on_complex_plane,
|
||||||
|
jtu.dtypes.supported([np.complex64, np.complex128]),
|
||||||
|
))
|
||||||
|
@jtu.skip_on_devices("tpu")
|
||||||
|
def testFailureOnComplexPlane(self, name, dtype):
|
||||||
|
self._testOnComplexPlaneWorker(name, dtype, 'failure')
|
||||||
|
|
||||||
|
def _testOnComplexPlaneWorker(self, name, dtype, kind):
|
||||||
|
try:
|
||||||
|
import mpmath
|
||||||
|
except ImportError as msg:
|
||||||
|
self.skipTest(f'could not import mpmath: {msg}')
|
||||||
|
|
||||||
is_cpu = jtu.test_device_matches(["cpu"])
|
is_cpu = jtu.test_device_matches(["cpu"])
|
||||||
machine = platform.machine()
|
machine = platform.machine()
|
||||||
|
# TODO: remove is_arm_cpu as previously arm cpu related failures
|
||||||
|
# were due to numpy issues. Confirm?
|
||||||
is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm')
|
is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm')
|
||||||
is_cuda = jtu.test_device_matches(["cuda"])
|
is_cuda = jtu.test_device_matches(["cuda"])
|
||||||
|
|
||||||
# TODO(pearu): eliminate all items in the following lists:
|
size_re = 11
|
||||||
# TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests
|
size_im = 11
|
||||||
regions_with_inaccuracies = dict(
|
atol = None
|
||||||
absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [],
|
|
||||||
exp = (['pos', 'pinfj', 'pinf', 'ninfj', 'ninf']
|
|
||||||
+ (['q1', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
|
|
||||||
exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])],
|
|
||||||
log = ['q1', 'q2', 'q3', 'q4'],
|
|
||||||
log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'],
|
|
||||||
log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'],
|
|
||||||
sinh = (['pos', 'neg', 'ninf', 'pinf']
|
|
||||||
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
|
|
||||||
cosh = (['pos', 'neg', 'ninf', 'pinf']
|
|
||||||
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
|
|
||||||
tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
|
|
||||||
square = (['pinf']
|
|
||||||
+ (['ninfj', 'pinfj'] if is_arm_cpu else [])
|
|
||||||
+ (['ninf'] if not is_arm_cpu else [])
|
|
||||||
+ (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
|
|
||||||
+ (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])),
|
|
||||||
sinc = ['q1', 'q2', 'q3', 'q4'],
|
|
||||||
arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
|
||||||
arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
|
||||||
arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
|
||||||
arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
|
||||||
arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
|
||||||
arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
|
||||||
sin = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
|
|
||||||
cos = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
|
|
||||||
expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [],
|
|
||||||
)
|
|
||||||
|
|
||||||
if jtu.numpy_version() < (2, 0, 0):
|
mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1)
|
||||||
regions_with_inaccuracies['sign'] = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj']
|
mnp2 = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=1)
|
||||||
|
|
||||||
|
ref_op = getattr(mnp, name)
|
||||||
|
ref2_op = getattr(mnp2, name)
|
||||||
jnp_op = getattr(jnp, name)
|
jnp_op = getattr(jnp, name)
|
||||||
|
|
||||||
if name == 'square':
|
|
||||||
# numpy square is incorrect on inputs with large absolute value
|
|
||||||
tiny = np.finfo(dtype).tiny
|
|
||||||
|
|
||||||
def square(x):
|
|
||||||
re = (x.real - x.imag) * (x.real + x.imag)
|
|
||||||
im = x.real * x.imag * 2
|
|
||||||
if is_cuda:
|
|
||||||
# apply FTZ
|
|
||||||
if np.isfinite(re) and abs(re) < tiny:
|
|
||||||
re *= 0
|
|
||||||
if np.isfinite(im) and abs(im) < tiny:
|
|
||||||
im *= 0
|
|
||||||
return np.array(complex(re, im), dtype=dtype)
|
|
||||||
|
|
||||||
np_op = np.vectorize(square)
|
|
||||||
else:
|
|
||||||
np_op = getattr(np, name)
|
|
||||||
|
|
||||||
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||||
args = (jtu.complex_plane_sample(dtype=dtype, size_re=11),)
|
args = (jtu.complex_plane_sample(dtype=dtype, size_re=size_re, size_im=size_im),)
|
||||||
result = np.asarray(jnp_op(*args))
|
result = np.asarray(jnp_op(*args))
|
||||||
expected = np_op(*args)
|
expected = ref_op(*args)
|
||||||
|
expected2 = ref2_op(*args)
|
||||||
|
|
||||||
s0, s1 = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2
|
normalized_expected, normalized_result = mnp2.normalize(expected2, expected, result)
|
||||||
|
|
||||||
|
# When comparing the results with expected, we'll divide the
|
||||||
|
# complex plane grid into smaller regions and perform the
|
||||||
|
# closeness tests on each region separately. The reason for this
|
||||||
|
# is that the inaccuracy or incorrectness issues with a particular
|
||||||
|
# function exists typically in specific regions while in other
|
||||||
|
# regions the function is accurate. So, such a division of the
|
||||||
|
# complex plane helps to identify the problematic regions as well
|
||||||
|
# as to fix the inaccuracy or incorrectness issues.
|
||||||
|
#
|
||||||
|
# Regions in complex plane:
|
||||||
|
#
|
||||||
|
# ( pinfj )
|
||||||
|
# ( q2 ) (posj) ( q1 )
|
||||||
|
# (ninf) ( neg ) (zero) ( pos ) (pinf)
|
||||||
|
# ( q3 ) (negj) ( q4 )
|
||||||
|
# ( ninfj )
|
||||||
|
#
|
||||||
|
# In addition, the 1/3 middle parts of regions q1, q2, q3, q4,
|
||||||
|
# neg, pos are tested separately as these don't contain extremely
|
||||||
|
# small or extremelly large values and functions on these regions
|
||||||
|
# ought not to possess any incorrectness issues.
|
||||||
|
|
||||||
|
s0, s1 = size_re, size_im
|
||||||
|
s03, s13 = s0 // 3, s1 // 3
|
||||||
s_dict = dict(
|
s_dict = dict(
|
||||||
q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)),
|
q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)),
|
||||||
q2=(slice(s0 + 2, -1), slice(1, s1 + 1)),
|
q2=(slice(s0 + 2, -1), slice(1, s1 + 1)),
|
||||||
@ -3462,38 +3508,167 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
|||||||
zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)),
|
zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)),
|
||||||
)
|
)
|
||||||
|
|
||||||
for region in all_regions:
|
if s03 and s13:
|
||||||
if is_arm_cpu:
|
s_dict.update(
|
||||||
if (
|
mq1 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
|
||||||
(
|
mq2 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
|
||||||
region in ['q1', 'q2', 'q3', 'q4']
|
mq3 = (slice(2 + s03, 2 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
|
||||||
and name in ['cos', 'cosh', 'sin', 'sinh', 'exp', 'expm1']
|
mq4 = (slice(2 + s03, 2 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
|
||||||
)
|
mneg=(s0 + 1, slice(2 + s13, 2 + 2 * s13)),
|
||||||
or (region in ['pinfj', 'ninfj'] and name in ['sin', 'cos'])
|
mpos=(s0 + 1, slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
|
||||||
or (region == 'pinf' and name in ['expm1'])
|
mnegj=(slice(2 + s03, 2 + 2 * s03), s1 + 1),
|
||||||
):
|
mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1),
|
||||||
continue
|
)
|
||||||
s = s_dict[region]
|
|
||||||
inds = np.where(result[s] != expected[s])
|
|
||||||
if inds[0].size > 0:
|
|
||||||
mismatches = []
|
|
||||||
for ind in zip(*inds):
|
|
||||||
x, r, e = args[0][s][ind], str(result[s][ind]), str(expected[s][ind])
|
|
||||||
if r == e:
|
|
||||||
# skip equal nan-s
|
|
||||||
continue
|
|
||||||
mismatches.append(f'jax.numpy.{name}{x} -> {r}, expected {e}')
|
|
||||||
mismatches = "\n".join(mismatches)
|
|
||||||
else:
|
|
||||||
mismatches = ''
|
|
||||||
if kind == 'success' and region not in regions_with_inaccuracies.get(name, []):
|
|
||||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
|
||||||
self.assertAllClose(result[s], expected[s], err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}")
|
|
||||||
if kind == 'failure' and region in regions_with_inaccuracies.get(name, []):
|
|
||||||
with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
|
|
||||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
|
||||||
self.assertAllClose(result[s], expected[s]) # on success, update regions_with_inaccuracies
|
|
||||||
|
|
||||||
|
# Start with an assumption that all regions are problematic for a
|
||||||
|
# particular function:
|
||||||
|
regions_with_inaccuracies = list(s_dict)
|
||||||
|
|
||||||
|
# Next, we'll remove non-problematic regions from the
|
||||||
|
# regions_with_inaccuracies list by explicitly keeping problematic
|
||||||
|
# regions:
|
||||||
|
def regions_with_inaccuracies_keep(*to_keep):
|
||||||
|
for item in regions_with_inaccuracies[:]:
|
||||||
|
if item not in to_keep:
|
||||||
|
regions_with_inaccuracies.remove(item)
|
||||||
|
|
||||||
|
if name == 'absolute':
|
||||||
|
if is_cuda and dtype == np.complex128:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
|
||||||
|
else:
|
||||||
|
regions_with_inaccuracies.clear()
|
||||||
|
|
||||||
|
elif name == 'sign':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
|
||||||
|
|
||||||
|
elif name == 'square':
|
||||||
|
if is_cuda:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||||
|
if is_cpu:
|
||||||
|
regions_with_inaccuracies_keep('ninf', 'pinf')
|
||||||
|
|
||||||
|
elif name == 'log':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||||
|
|
||||||
|
elif name == 'log10':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero')
|
||||||
|
|
||||||
|
elif name == 'log1p':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||||
|
|
||||||
|
elif name == 'exp':
|
||||||
|
regions_with_inaccuracies_keep('pos', 'pinf', 'mpos')
|
||||||
|
|
||||||
|
elif name == 'exp2':
|
||||||
|
if dtype == np.complex64:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos', 'mnegj', 'mposj')
|
||||||
|
if dtype == np.complex128:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mpos')
|
||||||
|
|
||||||
|
elif name == 'expm1':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
|
||||||
|
|
||||||
|
elif name == 'sinc':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
|
||||||
|
|
||||||
|
elif name == 'tan':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
|
||||||
|
|
||||||
|
elif name == 'sinh':
|
||||||
|
if is_cuda:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
|
||||||
|
if is_cpu:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
|
||||||
|
|
||||||
|
elif name == 'cosh':
|
||||||
|
regions_with_inaccuracies_keep('neg', 'pos', 'ninf', 'pinf', 'mneg', 'mpos')
|
||||||
|
|
||||||
|
elif name == 'tanh':
|
||||||
|
regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
|
||||||
|
|
||||||
|
elif name == 'arccos':
|
||||||
|
if dtype == np.complex64:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
|
||||||
|
if dtype == np.complex128:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
|
||||||
|
|
||||||
|
elif name == 'arccosh':
|
||||||
|
if dtype == np.complex64:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
|
||||||
|
if dtype == np.complex128:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj')
|
||||||
|
|
||||||
|
elif name == 'arcsin':
|
||||||
|
if dtype == np.complex64:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
|
||||||
|
if dtype == np.complex128:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
|
||||||
|
|
||||||
|
elif name == 'arcsinh':
|
||||||
|
if dtype == np.complex64:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
|
||||||
|
if dtype == np.complex128:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mnegj')
|
||||||
|
|
||||||
|
elif name == 'arctan':
|
||||||
|
if dtype == np.complex64:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
|
||||||
|
if dtype == np.complex128:
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||||
|
|
||||||
|
elif name == 'arctanh':
|
||||||
|
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
|
||||||
|
|
||||||
|
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt'}:
|
||||||
|
regions_with_inaccuracies.clear()
|
||||||
|
else:
|
||||||
|
assert 0 # unreachable
|
||||||
|
|
||||||
|
# Finally, perform the closeness tests per region:
|
||||||
|
unexpected_success_regions = []
|
||||||
|
for region_name, region_slice in s_dict.items():
|
||||||
|
region = args[0][region_slice]
|
||||||
|
inexact_indices = np.where(normalized_result[region_slice] != normalized_expected[region_slice])
|
||||||
|
|
||||||
|
if inexact_indices[0].size == 0:
|
||||||
|
inexact_samples = ''
|
||||||
|
else:
|
||||||
|
inexact_samples = []
|
||||||
|
for ind in zip(*inexact_indices):
|
||||||
|
x = region[ind]
|
||||||
|
y1, y2 = result[region_slice][ind], expected[region_slice][ind]
|
||||||
|
ny1, ny2 = normalized_result[region_slice][ind], normalized_expected[region_slice][ind]
|
||||||
|
if str(y1) == str(y2): # skip equal nan-s
|
||||||
|
continue
|
||||||
|
max_abs_diff = abs(ny1 - ny2).max() if np.isfinite(y1) and np.isfinite(y1) else np.inf
|
||||||
|
inexact_samples.append((max_abs_diff, f'jax.numpy.{name}({x}) -> {y1} [{ny1}], expected {y2} [{ny2}]'))
|
||||||
|
inexact_samples = "\n".join([msg for _, msg in sorted(inexact_samples)])
|
||||||
|
|
||||||
|
if kind == 'success' and region_name not in regions_with_inaccuracies:
|
||||||
|
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||||
|
self.assertAllClose(
|
||||||
|
normalized_result[region_slice], normalized_expected[region_slice], atol=atol,
|
||||||
|
err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{inexact_samples}")
|
||||||
|
|
||||||
|
if kind == 'failure' and region_name in regions_with_inaccuracies:
|
||||||
|
try:
|
||||||
|
with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
|
||||||
|
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||||
|
self.assertAllClose(normalized_result[region_slice], normalized_expected[region_slice])
|
||||||
|
except AssertionError as msg:
|
||||||
|
if str(msg).startswith('AssertionError not raised'):
|
||||||
|
unexpected_success_regions.append(region_name)
|
||||||
|
else:
|
||||||
|
raise # something else is wrong..
|
||||||
|
|
||||||
|
if kind == 'success' and regions_with_inaccuracies:
|
||||||
|
reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies)
|
||||||
|
raise unittest.SkipTest(reason)
|
||||||
|
|
||||||
|
if kind == 'failure':
|
||||||
|
self.assertEqual(unexpected_success_regions, []) # regions_with_inaccuracies requires an update!
|
||||||
|
if not regions_with_inaccuracies:
|
||||||
|
raise unittest.SkipTest("no problematic regions")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user