Evaluate the correctness of JAX complex functions using mpmath as a reference

This commit is contained in:
Pearu Peterson 2024-03-21 23:35:29 +02:00
parent 2848cda34c
commit fdb5015909
3 changed files with 581 additions and 104 deletions

View File

@ -4,6 +4,7 @@ cloudpickle
colorama>=0.4.4
flatbuffers
hypothesis
mpmath>=1.3
numpy>=1.22
pillow>=9.1.0
portpicker

View File

@ -1467,11 +1467,11 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
>>> print(complex_plane_sample(np.complex64, 0, 3))
[[-inf -infj 0. -infj inf -infj]
[-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
[-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j]
[-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j]
[-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
[-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
[-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
[-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j]
[-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j]
[-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
[-inf +infj 0. +infj inf +infj]]
@ -1481,16 +1481,18 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
finfo = np.finfo(dtype)
def make_axis_points(size):
logmin = np.log10(abs(finfo.min))
logtiny = np.log10(finfo.tiny)
logmax = np.log10(finfo.max)
prec_dps_ratio = 3.3219280948873626
logmin = logmax = finfo.maxexp / prec_dps_ratio
logtiny = finfo.minexp / prec_dps_ratio
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
with warnings.catch_warnings():
# Silence RuntimeWarning: overflow encountered in cast
warnings.simplefilter("ignore")
axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype)
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
half_line = -half_neg_line[::-1]
axis_points[-size - 1:-1] = half_line
axis_points[1:size + 1] = half_neg_line
if size > 1:
axis_points[1] = finfo.min
@ -1512,3 +1514,302 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)
return real_part + imag_part
class vectorize_with_mpmath(np.vectorize):
  """Same as numpy.vectorize but using mpmath backend for function evaluation.

  NumPy floating point and complex inputs are converted to mpmath numbers,
  the wrapped function is applied to them element-wise, and mpmath results
  are converted back to NumPy scalars/arrays. One mpmath context is created
  per floating point format listed in ``float_prec``; its precision is set
  to the number of bits given there.

  Constructor keyword arguments consumed by this class (everything else is
  forwarded to numpy.vectorize):

    mpmath: the mpmath module to use. Required; ValueError is raised if it
      is missing.
    extra_prec_multiplier: extra working precision used while evaluating
      the function, as a multiple of the context precision. Defaults to 0.
    extra_prec: extra working precision in bits, added on top of the
      multiplier-derived amount. Defaults to 0.
  """

  # Names of the complex dtype that corresponds to a floating point dtype,
  # and the reverse mapping.
  map_float_to_complex = dict(
    float16='complex32',
    float32='complex64',
    float64='complex128',
    float128='complex256',
    longdouble='clongdouble',
  )
  map_complex_to_float = {v: k for k, v in map_float_to_complex.items()}

  # Precision (in bits) of the mpmath context created for each supported
  # floating point format. Only the formats listed here get a context;
  # float16, float128 and longdouble are currently not enabled.
  float_prec = dict(
    float32=24,
    float64=53,
  )

  # Exponent thresholds used by mptonp: results with exp + bitcount below
  # float_minexp are flushed to zero, results above float_maxexp become
  # infinities.
  float_minexp = dict(
    float16=-14,
    float32=-126,
    float64=-1022,
    float128=-16382
  )

  float_maxexp = dict(
    float16=16,
    float32=128,
    float64=1024,
    float128=16384,
  )

  def __init__(self, *args, **kwargs):
    mpmath = kwargs.pop('mpmath', None)
    if mpmath is None:
      raise ValueError('vectorize_with_mpmath: no mpmath argument specified')
    self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0)
    self.extra_prec = kwargs.pop('extra_prec', 0)
    self.mpmath = mpmath
    # fp format name -> mpmath context, and the inverse mapping that lets
    # mptonp recover the NumPy dtype from an mpmath number's context.
    self.contexts = dict()
    self.contexts_inv = dict()
    for fp_format, prec in self.float_prec.items():
      ctx = self.mpmath.mp.clone()
      ctx.prec = prec
      self.contexts[fp_format] = ctx
      self.contexts_inv[ctx] = fp_format

    super().__init__(*args, **kwargs)

  def get_context(self, x):
    """Return the mpmath context that matches the dtype of x.

    x must be a NumPy array or a NumPy floating/complex scalar; complex
    dtypes map to the context of their floating point counterpart.
    Raises NotImplementedError for other types and KeyError when there is
    no context for the dtype.
    """
    if isinstance(x, (np.ndarray, np.floating, np.complexfloating)):
      fp_format = str(x.dtype)
      fp_format = self.map_complex_to_float.get(fp_format, fp_format)
      return self.contexts[fp_format]
    raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance')

  def nptomp(self, x):
    """Convert numpy array/scalar to an array/instance of mpmath number type.

    Arrays are converted element-wise into an object array of the same
    shape. Raises NotImplementedError for unsupported input types.
    """
    if isinstance(x, np.ndarray):
      return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape)
    elif isinstance(x, np.floating):
      mpmath = self.mpmath
      ctx = self.get_context(x)
      prec, rounding = ctx._prec_rounding
      if np.isposinf(x):
        return ctx.make_mpf(mpmath.libmp.finf)
      elif np.isneginf(x):
        return ctx.make_mpf(mpmath.libmp.fninf)
      elif np.isnan(x):
        return ctx.make_mpf(mpmath.libmp.fnan)
      elif np.isfinite(x):
        # Split x into mantissa * 2**exponent and rescale the mantissa to
        # an integer so that the value is handed to mpmath as
        # man * 2**exp.
        mantissa, exponent = np.frexp(x)
        man = int(np.ldexp(mantissa, prec))
        exp = int(exponent - prec)
        r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding))
        assert ctx.isfinite(r), r._mpf_
        return r
    elif isinstance(x, np.complexfloating):
      re, im = self.nptomp(x.real), self.nptomp(x.imag)
      return re.context.make_mpc((re._mpf_, im._mpf_))
    raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type')

  def mptonp(self, x):
    """Convert mpmath instance to numpy array/scalar type.

    Accepts an mpmath number or an object array of mpmath numbers; the
    NumPy dtype is derived from the mpmath context of the number (for
    arrays, of the first element). Finite values outside the exponent
    range of the target format are returned as signed zeros/infinities of
    that dtype. Raises NotImplementedError for unsupported inputs.
    """
    if isinstance(x, np.ndarray) and x.dtype.kind == 'O':
      x_flat = x.flatten()
      item = x_flat[0]
      ctx = item.context
      fp_format = self.contexts_inv[ctx]
      if isinstance(item, ctx.mpc):
        dtype = getattr(np, self.map_float_to_complex[fp_format])
      elif isinstance(item, ctx.mpf):
        dtype = getattr(np, fp_format)
      else:
        dtype = None
      if dtype is not None:
        return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape)
    elif isinstance(x, self.mpmath.ctx_mp.mpnumeric):
      ctx = x.context
      if isinstance(x, ctx.mpc):
        fp_format = self.contexts_inv[ctx]
        dtype = getattr(np, self.map_float_to_complex[fp_format])
        # View a complex scalar as two floats so that the real and the
        # imaginary parts can be filled in separately.
        r = dtype().reshape(1).view(getattr(np, fp_format))
        r[0] = self.mptonp(x.real)
        r[1] = self.mptonp(x.imag)
        return r.view(dtype)[0]
      elif isinstance(x, ctx.mpf):
        fp_format = self.contexts_inv[ctx]
        dtype = getattr(np, fp_format)
        if ctx.isfinite(x):
          sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding)
          assert bc >= 0, (sign, man, exp, bc, x._mpf_)
          # Out-of-range values are returned as NumPy scalars of the
          # target dtype (not as mpmath numbers), keeping the sign.
          if exp + bc < self.float_minexp[fp_format]:
            return dtype(-0.0 if sign else 0.0)
          if exp + bc > self.float_maxexp[fp_format]:
            return dtype(-np.inf if sign else np.inf)
          man = dtype(-man if sign else man)
          r = np.ldexp(man, exp)
          assert np.isfinite(r), (x, r, x._mpf_, man)
          return r
        elif ctx.isnan(x):
          return dtype(np.nan)
        elif ctx.isinf(x):
          return dtype(-np.inf if x._mpf_[0] else np.inf)
    raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type')

  def _is_mp_result(self, r):
    # True when r is an mpmath number or an object array, i.e. something
    # that mptonp is expected to convert back to NumPy.
    return ((isinstance(r, np.ndarray) and r.dtype.kind == 'O')
            or isinstance(r, self.mpmath.ctx_mp.mpnumeric))

  def __call__(self, *args, **kwargs):
    """Evaluate the wrapped function on NumPy inputs via mpmath.

    NumPy array/floating/complex positional arguments are converted to
    mpmath numbers (all of them must share the same mpmath context); other
    arguments are passed through unchanged. The function is evaluated with
    the configured extra precision and mpmath results (also inside a tuple
    result) are converted back to NumPy.

    Raises TypeError when no positional argument is a NumPy
    array/floating/complex value, because the mpmath context cannot be
    determined in that case.
    """
    mp_args = []
    context = None
    for a in args:
      if isinstance(a, (np.ndarray, np.floating, np.complexfloating)):
        mp_args.append(self.nptomp(a))
        if context is None:
          context = self.get_context(a)
        else:
          assert context is self.get_context(a), 'arguments use different floating point formats'
      else:
        mp_args.append(a)

    if context is None:
      raise TypeError(
        'vectorize_with_mpmath: expected at least one numpy array or'
        ' numpy floating/complex scalar argument')

    extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec
    with context.extraprec(extra_prec):
      result = super().__call__(*mp_args, **kwargs)

    if isinstance(result, tuple):
      return tuple(self.mptonp(r) if self._is_mp_result(r) else r for r in result)

    if self._is_mp_result(result):
      return self.mptonp(result)

    return result
class numpy_with_mpmath:
  """Namespace of universal functions on numpy arrays that use mpmath
  backend for evaluation and return numpy arrays as outputs.

  For every name in ``_provides`` the instance gets an attribute of that
  name holding a vectorize_with_mpmath wrapper. The wrapped function is the
  method of the same name when this class defines one, otherwise the
  function of the mpmath context looked up under the name given by
  ``_mp_names`` (falling back to the same name).

  Constructor arguments:
    mpmath: the mpmath module to use.
    extra_prec_multiplier, extra_prec: forwarded to vectorize_with_mpmath.
  """

  _provides = [
    'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2',
    'log', 'log1p', 'log10', 'log2',
    'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
    'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
    'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc',
    'normalize',
  ]

  # numpy function name -> mpmath context function name, for the names
  # that differ.
  # NOTE(review): 'sinc' is not remapped; mpmath documents sinc(x) as
  # sin(x)/x and sincpi(x) as sin(pi*x)/(pi*x), while numpy.sinc is the
  # normalized variant -- confirm which reference is intended.
  _mp_names = dict(
    abs='absmin', absolute='absmin',
    log='ln',
    arcsin='asin', arccos='acos', arctan='atan',
    arcsinh='asinh', arccosh='acosh', arctanh='atanh',
  )

  def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0):
    self.mpmath = mpmath

    for name in self._provides:
      mp_name = self._mp_names.get(name, name)

      if hasattr(self, name):
        op = getattr(self, name)
      else:

        # mp_name is bound as a default argument so that each closure
        # keeps its own name rather than the last one of the loop.
        def op(x, mp_name=mp_name):
          return getattr(x.context, mp_name)(x)

      setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec))

  # The following function methods operate on mpmath number instances.
  # The corresponding function names must be listed in
  # numpy_with_mpmath._provides list.

  def square(self, x):
    return x * x

  def positive(self, x):
    return x

  def negative(self, x):
    return -x

  def sqrt(self, x):
    """Square root of an mpmath number.

    Complex inputs with an infinite imaginary part and a finite or
    infinite real part are handled here instead of by mpmath (workaround
    for mpmath bugs): the result is ``inf + inf*j`` when the imaginary
    part is +inf and ``inf - inf*j`` when it is -inf, except that inputs
    with real part +inf are returned unchanged. Everything else is
    delegated to the mpmath context.
    """
    ctx = x.context
    # workaround mpmath bugs:
    if isinstance(x, ctx.mpc):
      real_is_inf = ctx.isinf(x.real)
      if ctx.isinf(x.imag) and (real_is_inf or ctx.isfinite(x.real)):
        if real_is_inf and x.real > 0:
          return x
        if x.imag > 0:
          inf = x.imag
          return ctx.make_mpc((inf._mpf_, inf._mpf_))
        # The imaginary part is -inf: the result keeps its sign. This
        # also applies when the real part is -inf.
        ninf = x.imag
        inf = -ninf
        return ctx.make_mpc((inf._mpf_, ninf._mpf_))
    return ctx.sqrt(x)

  def expm1(self, x):
    return x.context.expm1(x)

  def log2(self, x):
    return x.context.ln(x) / x.context.ln2

  def log10(self, x):
    return x.context.ln(x) / x.context.ln10

  def exp2(self, x):
    return x.context.exp(x * x.context.ln2)

  def normalize(self, exact, reference, value):
    """Normalize reference and value using precision defined by the
    difference of exact and reference.

    All arguments are mpmath numbers of the same kind (mpf or mpc); for
    complex numbers the real and imaginary parts are processed
    separately, both using abs(exact) as the scale. Returns the pair
    (normalized reference, normalized value). If exact, reference or value
    is not finite, reference and value are returned unchanged.
    """
    def worker(ctx, s, e, r, v):
      # Unpack the (sign, mantissa, exponent, bitcount) representations.
      ss, sm, se, sbc = s._mpf_
      es, em, ee, ebc = e._mpf_
      rs, rm, re, rbc = r._mpf_
      vs, vm, ve, vbc = v._mpf_

      if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)):
        return r, v

      me = min(se, ee, re, ve)

      # transform mantissa parts to the same exponent base
      sm_e = sm << (se - me)
      em_e = em << (ee - me)
      rm_e = rm << (re - me)
      vm_e = vm << (ve - me)

      # find matching higher and non-matching lower bits of e and r
      sm_b = bin(sm_e)[2:] if sm_e else ''
      em_b = bin(em_e)[2:] if em_e else ''
      rm_b = bin(rm_e)[2:] if rm_e else ''
      vm_b = bin(vm_e)[2:] if vm_e else ''

      # Common bit width: e and r are left-padded with zeros to it.
      m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b))
      em_b = '0' * (m - len(em_b)) + em_b
      rm_b = '0' * (m - len(rm_b)) + rm_b

      # c1: number of leading bits on which e and r agree;
      # c0: number of remaining low bits, dropped below.
      c1 = 0
      for b0, b1 in zip(em_b, rm_b):
        if b0 != b1:
          break
        c1 += 1
      c0 = m - c1

      # truncate r and v mantissa
      rm_m = rm_e >> c0
      vm_m = vm_e >> c0

      # normalized r and v
      nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero)
      nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero)
      return nr, nv

    ctx = exact.context
    scale = abs(exact)
    if isinstance(exact, ctx.mpc):
      rr, rv = worker(ctx, scale, exact.real, reference.real, value.real)
      ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag)
      return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_))
    elif isinstance(exact, ctx.mpf):
      return worker(ctx, scale, exact, reference, value)
    else:
      assert 0  # unreachable

View File

@ -3365,87 +3365,133 @@ class CustomElementTypesTest(jtu.JaxTestCase):
class FunctionAccuracyTest(jtu.JaxTestCase):
@parameterized.named_parameters(
dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind)
for name, dtype, kind in itertools.product(
[ 'arccos', 'arccosh', 'arcsin', 'arcsinh',
'arctan', 'arctanh', 'conjugate', 'cos',
'cosh', 'exp', 'exp2', 'expm1', 'log',
'log10', 'log1p', 'sin', 'sinh', 'sqrt',
'square', 'tan', 'tanh', 'sinc', 'positive',
'negative', 'absolute', 'sign'],
dict(testcase_name=f"_{dtype.__name__}", dtype=dtype)
for dtype in jtu.dtypes.supported([np.float32, np.float64, np.complex64, np.complex128]))
def testMPMathUtils(self, dtype):
  """Round-trip checks for the mpmath helpers in jtu.

  Verifies that numpy -> mpmath -> numpy conversion is exact on a sample
  of values, that the vectorized identity function receives mpmath numbers
  of the expected kind and precision, and (complex dtypes only) that
  normalize maps a reference and a nearby value to close results.
  """
  try:
    import mpmath
  except ImportError as msg:
    self.skipTest(f'could not import mpmath: {msg}')

  warning_pattern = "(overflow|invalid value|divide by zero) encountered in.*"
  expected_prec = {np.float32: 24, np.float64: 53, np.complex64: 24, np.complex128: 53}[dtype]
  is_complex = dtype().dtype.kind == 'c'

  def identity(x):
    # Every element must arrive as an mpmath number from the context
    # that matches dtype.
    assert isinstance(x, mpmath.ctx_mp.mpnumeric)
    assert x.context.prec == expected_prec
    expected_type = x.context.mpc if is_complex else x.context.mpf
    assert isinstance(x, expected_type)
    return x

  mp_identity = jtu.vectorize_with_mpmath(identity, mpmath=mpmath)

  with jtu.ignore_warning(category=RuntimeWarning, message=warning_pattern):
    if not is_complex:
      # Take the real parts of one row of the complex plane sample.
      complex_dtype = getattr(np, mp_identity.map_float_to_complex[dtype.__name__])
      plane = jtu.complex_plane_sample(dtype=complex_dtype, size_re=11, size_im=0)
      samples = plane[1:2].real
    else:
      samples = jtu.complex_plane_sample(dtype=dtype, size_re=11)

  # Explicit conversion there and back must be lossless.
  round_tripped = mp_identity.mptonp(mp_identity.nptomp(samples))
  with jtu.ignore_warning(category=RuntimeWarning, message=warning_pattern):
    self.assertAllClose(samples, round_tripped, atol=0, rtol=0)

  # The same holds when going through the vectorized call.
  via_call = mp_identity(samples)
  with jtu.ignore_warning(category=RuntimeWarning, message=warning_pattern):
    self.assertAllClose(samples, via_call, atol=0, rtol=0)

  if not is_complex:
    return

  # tests scale in normalize
  value = dtype(1.1071487177940644+1.1102230246251565e-16j)
  reference = dtype(1.1071487177940644+0j)
  mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1)
  normalized_reference, normalized_value = mnp.normalize(reference, reference, value)
  self.assertAllClose(normalized_reference, normalized_value)
_functions_on_complex_plane = [
'arccos', 'arccosh', 'arcsin', 'arcsinh',
'arctan', 'arctanh', 'conjugate', 'cos',
'cosh', 'exp', 'exp2', 'expm1', 'log',
'log10', 'log1p', 'sin', 'sinh', 'sqrt',
'square', 'tan', 'tanh', 'sinc', 'positive',
'negative', 'absolute', 'sign'
]
@parameterized.named_parameters(
dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
for name, dtype in itertools.product(
_functions_on_complex_plane,
jtu.dtypes.supported([np.complex64, np.complex128]),
['success', 'failure'],
))
@jtu.skip_on_devices("tpu")
def testOnComplexPlane(self, name, dtype, kind):
all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero']
def testSuccessOnComplexPlane(self, name, dtype):
self._testOnComplexPlaneWorker(name, dtype, 'success')
@parameterized.named_parameters(
  {"testcase_name": f"_{name}_{dtype.__name__}", "name": name, "dtype": dtype}
  for name, dtype in itertools.product(
    _functions_on_complex_plane,
    jtu.dtypes.supported([np.complex64, np.complex128]),
  ))
@jtu.skip_on_devices("tpu")
def testFailureOnComplexPlane(self, name, dtype):
  """Run the complex plane worker for `name` and `dtype` in 'failure' mode."""
  self._testOnComplexPlaneWorker(name, dtype, kind='failure')
def _testOnComplexPlaneWorker(self, name, dtype, kind):
try:
import mpmath
except ImportError as msg:
self.skipTest(f'could not import mpmath: {msg}')
is_cpu = jtu.test_device_matches(["cpu"])
machine = platform.machine()
# TODO: remove is_arm_cpu as previously arm cpu related failures
# were due to numpy issues. Confirm?
is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm')
is_cuda = jtu.test_device_matches(["cuda"])
# TODO(pearu): eliminate all items in the following lists:
# TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests
regions_with_inaccuracies = dict(
absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [],
exp = (['pos', 'pinfj', 'pinf', 'ninfj', 'ninf']
+ (['q1', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])],
log = ['q1', 'q2', 'q3', 'q4'],
log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'],
log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'],
sinh = (['pos', 'neg', 'ninf', 'pinf']
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
cosh = (['pos', 'neg', 'ninf', 'pinf']
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
square = (['pinf']
+ (['ninfj', 'pinfj'] if is_arm_cpu else [])
+ (['ninf'] if not is_arm_cpu else [])
+ (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
+ (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])),
sinc = ['q1', 'q2', 'q3', 'q4'],
arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
sin = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
cos = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [],
)
size_re = 11
size_im = 11
atol = None
if jtu.numpy_version() < (2, 0, 0):
regions_with_inaccuracies['sign'] = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj']
mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1)
mnp2 = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=1)
ref_op = getattr(mnp, name)
ref2_op = getattr(mnp2, name)
jnp_op = getattr(jnp, name)
if name == 'square':
# numpy square is incorrect on inputs with large absolute value
tiny = np.finfo(dtype).tiny
def square(x):
re = (x.real - x.imag) * (x.real + x.imag)
im = x.real * x.imag * 2
if is_cuda:
# apply FTZ
if np.isfinite(re) and abs(re) < tiny:
re *= 0
if np.isfinite(im) and abs(im) < tiny:
im *= 0
return np.array(complex(re, im), dtype=dtype)
np_op = np.vectorize(square)
else:
np_op = getattr(np, name)
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
args = (jtu.complex_plane_sample(dtype=dtype, size_re=11),)
args = (jtu.complex_plane_sample(dtype=dtype, size_re=size_re, size_im=size_im),)
result = np.asarray(jnp_op(*args))
expected = np_op(*args)
expected = ref_op(*args)
expected2 = ref2_op(*args)
s0, s1 = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2
normalized_expected, normalized_result = mnp2.normalize(expected2, expected, result)
# When comparing the results with expected, we'll divide the
# complex plane grid into smaller regions and perform the
# closeness tests on each region separately. The reason for this
# is that the inaccuracy or incorrectness issues with a particular
# function exists typically in specific regions while in other
# regions the function is accurate. So, such a division of the
# complex plane helps to identify the problematic regions as well
# as to fix the inaccuracy or incorrectness issues.
#
# Regions in complex plane:
#
# ( pinfj )
# ( q2 ) (posj) ( q1 )
# (ninf) ( neg ) (zero) ( pos ) (pinf)
# ( q3 ) (negj) ( q4 )
# ( ninfj )
#
# In addition, the middle thirds of regions q1, q2, q3, q4, neg, and
# pos are tested separately: these parts do not contain extremely
# small or extremely large values, so functions ought not to have
# any incorrectness issues there.
s0, s1 = size_re, size_im
s03, s13 = s0 // 3, s1 // 3
s_dict = dict(
q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)),
q2=(slice(s0 + 2, -1), slice(1, s1 + 1)),
@ -3462,38 +3508,167 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)),
)
for region in all_regions:
if is_arm_cpu:
if (
(
region in ['q1', 'q2', 'q3', 'q4']
and name in ['cos', 'cosh', 'sin', 'sinh', 'exp', 'expm1']
)
or (region in ['pinfj', 'ninfj'] and name in ['sin', 'cos'])
or (region == 'pinf' and name in ['expm1'])
):
continue
s = s_dict[region]
inds = np.where(result[s] != expected[s])
if inds[0].size > 0:
mismatches = []
for ind in zip(*inds):
x, r, e = args[0][s][ind], str(result[s][ind]), str(expected[s][ind])
if r == e:
# skip equal nan-s
continue
mismatches.append(f'jax.numpy.{name}{x} -> {r}, expected {e}')
mismatches = "\n".join(mismatches)
else:
mismatches = ''
if kind == 'success' and region not in regions_with_inaccuracies.get(name, []):
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
self.assertAllClose(result[s], expected[s], err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}")
if kind == 'failure' and region in regions_with_inaccuracies.get(name, []):
with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
self.assertAllClose(result[s], expected[s]) # on success, update regions_with_inaccuracies
if s03 and s13:
s_dict.update(
mq1 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
mq2 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
mq3 = (slice(2 + s03, 2 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
mq4 = (slice(2 + s03, 2 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
mneg=(s0 + 1, slice(2 + s13, 2 + 2 * s13)),
mpos=(s0 + 1, slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
mnegj=(slice(2 + s03, 2 + 2 * s03), s1 + 1),
mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1),
)
# Start with an assumption that all regions are problematic for a
# particular function:
regions_with_inaccuracies = list(s_dict)
# Next, we'll remove non-problematic regions from the
# regions_with_inaccuracies list by explicitly keeping problematic
# regions:
def regions_with_inaccuracies_keep(*to_keep):
for item in regions_with_inaccuracies[:]:
if item not in to_keep:
regions_with_inaccuracies.remove(item)
if name == 'absolute':
if is_cuda and dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
else:
regions_with_inaccuracies.clear()
elif name == 'sign':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
elif name == 'square':
if is_cuda:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
if is_cpu:
regions_with_inaccuracies_keep('ninf', 'pinf')
elif name == 'log':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
elif name == 'log10':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero')
elif name == 'log1p':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
elif name == 'exp':
regions_with_inaccuracies_keep('pos', 'pinf', 'mpos')
elif name == 'exp2':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos', 'mnegj', 'mposj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mpos')
elif name == 'expm1':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
elif name == 'sinc':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
elif name == 'tan':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
elif name == 'sinh':
if is_cuda:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
if is_cpu:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
elif name == 'cosh':
regions_with_inaccuracies_keep('neg', 'pos', 'ninf', 'pinf', 'mneg', 'mpos')
elif name == 'tanh':
regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
elif name == 'arccos':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
elif name == 'arccosh':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj')
elif name == 'arcsin':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
elif name == 'arcsinh':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mnegj')
elif name == 'arctan':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
elif name == 'arctanh':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt'}:
regions_with_inaccuracies.clear()
else:
assert 0 # unreachable
# Finally, perform the closeness tests per region:
unexpected_success_regions = []
for region_name, region_slice in s_dict.items():
region = args[0][region_slice]
inexact_indices = np.where(normalized_result[region_slice] != normalized_expected[region_slice])
if inexact_indices[0].size == 0:
inexact_samples = ''
else:
inexact_samples = []
for ind in zip(*inexact_indices):
x = region[ind]
y1, y2 = result[region_slice][ind], expected[region_slice][ind]
ny1, ny2 = normalized_result[region_slice][ind], normalized_expected[region_slice][ind]
if str(y1) == str(y2): # skip equal nan-s
continue
max_abs_diff = abs(ny1 - ny2).max() if np.isfinite(y1) and np.isfinite(y1) else np.inf
inexact_samples.append((max_abs_diff, f'jax.numpy.{name}({x}) -> {y1} [{ny1}], expected {y2} [{ny2}]'))
inexact_samples = "\n".join([msg for _, msg in sorted(inexact_samples)])
if kind == 'success' and region_name not in regions_with_inaccuracies:
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
self.assertAllClose(
normalized_result[region_slice], normalized_expected[region_slice], atol=atol,
err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{inexact_samples}")
if kind == 'failure' and region_name in regions_with_inaccuracies:
try:
with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
self.assertAllClose(normalized_result[region_slice], normalized_expected[region_slice])
except AssertionError as msg:
if str(msg).startswith('AssertionError not raised'):
unexpected_success_regions.append(region_name)
else:
raise # something else is wrong..
if kind == 'success' and regions_with_inaccuracies:
reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies)
raise unittest.SkipTest(reason)
if kind == 'failure':
self.assertEqual(unexpected_success_regions, []) # regions_with_inaccuracies requires an update!
if not regions_with_inaccuracies:
raise unittest.SkipTest("no problematic regions")
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())