mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Evaluate the correctness of JAX complex functions using mpmath as a reference
This commit is contained in:
parent
2848cda34c
commit
fdb5015909
@ -4,6 +4,7 @@ cloudpickle
|
||||
colorama>=0.4.4
|
||||
flatbuffers
|
||||
hypothesis
|
||||
mpmath>=1.3
|
||||
numpy>=1.22
|
||||
pillow>=9.1.0
|
||||
portpicker
|
||||
|
@ -1467,11 +1467,11 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
|
||||
>>> print(complex_plane_sample(np.complex64, 0, 3))
|
||||
[[-inf -infj 0. -infj inf -infj]
|
||||
[-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
|
||||
[-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j]
|
||||
[-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j]
|
||||
[-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
|
||||
[-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
|
||||
[-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
|
||||
[-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j]
|
||||
[-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j]
|
||||
[-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
|
||||
[-inf +infj 0. +infj inf +infj]]
|
||||
|
||||
@ -1481,16 +1481,18 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
|
||||
finfo = np.finfo(dtype)
|
||||
|
||||
def make_axis_points(size):
|
||||
logmin = np.log10(abs(finfo.min))
|
||||
logtiny = np.log10(finfo.tiny)
|
||||
logmax = np.log10(finfo.max)
|
||||
prec_dps_ratio = 3.3219280948873626
|
||||
logmin = logmax = finfo.maxexp / prec_dps_ratio
|
||||
logtiny = finfo.minexp / prec_dps_ratio
|
||||
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
# Silence RuntimeWarning: overflow encountered in cast
|
||||
warnings.simplefilter("ignore")
|
||||
axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
|
||||
axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype)
|
||||
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
|
||||
half_line = -half_neg_line[::-1]
|
||||
axis_points[-size - 1:-1] = half_line
|
||||
axis_points[1:size + 1] = half_neg_line
|
||||
|
||||
if size > 1:
|
||||
axis_points[1] = finfo.min
|
||||
@ -1512,3 +1514,302 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
|
||||
imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)
|
||||
|
||||
return real_part + imag_part
|
||||
|
||||
|
||||
class vectorize_with_mpmath(np.vectorize):
|
||||
"""Same as numpy.vectorize but using mpmath backend for function evaluation.
|
||||
"""
|
||||
|
||||
map_float_to_complex = dict(float16='complex32', float32='complex64', float64='complex128', float128='complex256', longdouble='clongdouble')
|
||||
map_complex_to_float = {v: k for k, v in map_float_to_complex.items()}
|
||||
|
||||
float_prec = dict(
|
||||
# float16=11,
|
||||
float32=24,
|
||||
float64=53,
|
||||
# float128=113,
|
||||
# longdouble=113
|
||||
)
|
||||
|
||||
float_minexp = dict(
|
||||
float16=-14,
|
||||
float32=-126,
|
||||
float64=-1022,
|
||||
float128=-16382
|
||||
)
|
||||
|
||||
float_maxexp = dict(
|
||||
float16=16,
|
||||
float32=128,
|
||||
float64=1024,
|
||||
float128=16384,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
mpmath = kwargs.pop('mpmath', None)
|
||||
if mpmath is None:
|
||||
raise ValueError('vectorize_with_mpmath: no mpmath argument specified')
|
||||
self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0)
|
||||
self.extra_prec = kwargs.pop('extra_prec', 0)
|
||||
self.mpmath = mpmath
|
||||
self.contexts = dict()
|
||||
self.contexts_inv = dict()
|
||||
for fp_format, prec in self.float_prec.items():
|
||||
ctx = self.mpmath.mp.clone()
|
||||
ctx.prec = prec
|
||||
self.contexts[fp_format] = ctx
|
||||
self.contexts_inv[ctx] = fp_format
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_context(self, x):
|
||||
if isinstance(x, (np.ndarray, np.floating, np.complexfloating)):
|
||||
fp_format = str(x.dtype)
|
||||
fp_format = self.map_complex_to_float.get(fp_format, fp_format)
|
||||
return self.contexts[fp_format]
|
||||
raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance')
|
||||
|
||||
def nptomp(self, x):
|
||||
"""Convert numpy array/scalar to an array/instance of mpmath number type.
|
||||
"""
|
||||
if isinstance(x, np.ndarray):
|
||||
return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape)
|
||||
elif isinstance(x, np.floating):
|
||||
mpmath = self.mpmath
|
||||
ctx = self.get_context(x)
|
||||
prec, rounding = ctx._prec_rounding
|
||||
if np.isposinf(x):
|
||||
return ctx.make_mpf(mpmath.libmp.finf)
|
||||
elif np.isneginf(x):
|
||||
return ctx.make_mpf(mpmath.libmp.fninf)
|
||||
elif np.isnan(x):
|
||||
return ctx.make_mpf(mpmath.libmp.fnan)
|
||||
elif np.isfinite(x):
|
||||
mantissa, exponent = np.frexp(x)
|
||||
man = int(np.ldexp(mantissa, prec))
|
||||
exp = int(exponent - prec)
|
||||
r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding))
|
||||
assert ctx.isfinite(r), r._mpf_
|
||||
return r
|
||||
elif isinstance(x, np.complexfloating):
|
||||
re, im = self.nptomp(x.real), self.nptomp(x.imag)
|
||||
return re.context.make_mpc((re._mpf_, im._mpf_))
|
||||
raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type')
|
||||
|
||||
def mptonp(self, x):
|
||||
"""Convert mpmath instance to numpy array/scalar type.
|
||||
"""
|
||||
if isinstance(x, np.ndarray) and x.dtype.kind == 'O':
|
||||
x_flat = x.flatten()
|
||||
item = x_flat[0]
|
||||
ctx = item.context
|
||||
fp_format = self.contexts_inv[ctx]
|
||||
if isinstance(item, ctx.mpc):
|
||||
dtype = getattr(np, self.map_float_to_complex[fp_format])
|
||||
elif isinstance(item, ctx.mpf):
|
||||
dtype = getattr(np, fp_format)
|
||||
else:
|
||||
dtype = None
|
||||
if dtype is not None:
|
||||
return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape)
|
||||
elif isinstance(x, self.mpmath.ctx_mp.mpnumeric):
|
||||
ctx = x.context
|
||||
if isinstance(x, ctx.mpc):
|
||||
fp_format = self.contexts_inv[ctx]
|
||||
dtype = getattr(np, self.map_float_to_complex[fp_format])
|
||||
r = dtype().reshape(1).view(getattr(np, fp_format))
|
||||
r[0] = self.mptonp(x.real)
|
||||
r[1] = self.mptonp(x.imag)
|
||||
return r.view(dtype)[0]
|
||||
elif isinstance(x, ctx.mpf):
|
||||
fp_format = self.contexts_inv[ctx]
|
||||
dtype = getattr(np, fp_format)
|
||||
if ctx.isfinite(x):
|
||||
sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding)
|
||||
assert bc >= 0, (sign, man, exp, bc, x._mpf_)
|
||||
if exp + bc < self.float_minexp[fp_format]:
|
||||
return -ctx.zero if sign else ctx.zero
|
||||
if exp + bc > self.float_maxexp[fp_format]:
|
||||
return ctx.ninf if sign else ctx.inf
|
||||
man = dtype(-man if sign else man)
|
||||
r = np.ldexp(man, exp)
|
||||
assert np.isfinite(r), (x, r, x._mpf_, man)
|
||||
return r
|
||||
elif ctx.isnan(x):
|
||||
return dtype(np.nan)
|
||||
elif ctx.isinf(x):
|
||||
return dtype(-np.inf if x._mpf_[0] else np.inf)
|
||||
raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type')
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
mp_args = []
|
||||
context = None
|
||||
for a in args:
|
||||
if isinstance(a, (np.ndarray, np.floating, np.complexfloating)):
|
||||
mp_args.append(self.nptomp(a))
|
||||
if context is None:
|
||||
context = self.get_context(a)
|
||||
else:
|
||||
assert context is self.get_context(a)
|
||||
else:
|
||||
mp_args.append(a)
|
||||
|
||||
extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec
|
||||
with context.extraprec(extra_prec):
|
||||
result = super().__call__(*mp_args, **kwargs)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
lst = []
|
||||
for r in result:
|
||||
if ((isinstance(r, np.ndarray) and r.dtype.kind == 'O')
|
||||
or isinstance(r, self.mpmath.ctx_mp.mpnumeric)):
|
||||
r = self.mptonp(r)
|
||||
lst.append(r)
|
||||
return tuple(lst)
|
||||
|
||||
if ((isinstance(result, np.ndarray) and result.dtype.kind == 'O')
|
||||
or isinstance(result, self.mpmath.ctx_mp.mpnumeric)):
|
||||
return self.mptonp(result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class numpy_with_mpmath:
|
||||
"""Namespace of universal functions on numpy arrays that use mpmath
|
||||
backend for evaluation and return numpy arrays as outputs.
|
||||
"""
|
||||
|
||||
_provides = [
|
||||
'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2',
|
||||
'log', 'log1p', 'log10', 'log2',
|
||||
'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
|
||||
'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
|
||||
'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc',
|
||||
'normalize',
|
||||
]
|
||||
|
||||
_mp_names = dict(
|
||||
abs='absmin', absolute='absmin',
|
||||
log='ln',
|
||||
arcsin='asin', arccos='acos', arctan='atan',
|
||||
arcsinh='asinh', arccosh='acosh', arctanh='atanh',
|
||||
)
|
||||
|
||||
def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0):
|
||||
self.mpmath = mpmath
|
||||
|
||||
for name in self._provides:
|
||||
mp_name = self._mp_names.get(name, name)
|
||||
|
||||
if hasattr(self, name):
|
||||
op = getattr(self, name)
|
||||
else:
|
||||
|
||||
def op(x, mp_name=mp_name):
|
||||
return getattr(x.context, mp_name)(x)
|
||||
|
||||
setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec))
|
||||
|
||||
# The following function methods operate on mpmath number instances.
|
||||
# The corresponding function names must be listed in
|
||||
# numpy_with_mpmath._provides list.
|
||||
|
||||
def square(self, x):
|
||||
return x * x
|
||||
|
||||
def positive(self, x):
|
||||
return x
|
||||
|
||||
def negative(self, x):
|
||||
return -x
|
||||
|
||||
def sqrt(self, x):
|
||||
ctx = x.context
|
||||
# workaround mpmath bugs:
|
||||
if isinstance(x, ctx.mpc):
|
||||
if ctx.isinf(x.real) and ctx.isinf(x.imag):
|
||||
if x.real > 0: return x
|
||||
ninf = x.real
|
||||
inf = -ninf
|
||||
if x.imag > 0: return ctx.make_mpc((inf._mpf_, inf._mpf_))
|
||||
return ctx.make_mpc((inf._mpf_, inf._mpf_))
|
||||
elif ctx.isfinite(x.real) and ctx.isinf(x.imag):
|
||||
if x.imag > 0:
|
||||
inf = x.imag
|
||||
return ctx.make_mpc((inf._mpf_, inf._mpf_))
|
||||
else:
|
||||
ninf = x.imag
|
||||
inf = -ninf
|
||||
return ctx.make_mpc((inf._mpf_, ninf._mpf_))
|
||||
return ctx.sqrt(x)
|
||||
|
||||
def expm1(self, x):
|
||||
return x.context.expm1(x)
|
||||
|
||||
def log2(self, x):
|
||||
return x.context.ln(x) / x.context.ln2
|
||||
|
||||
def log10(self, x):
|
||||
return x.context.ln(x) / x.context.ln10
|
||||
|
||||
def exp2(self, x):
|
||||
return x.context.exp(x * x.context.ln2)
|
||||
|
||||
def normalize(self, exact, reference, value):
|
||||
"""Normalize reference and value using precision defined by the
|
||||
difference of exact and reference.
|
||||
"""
|
||||
def worker(ctx, s, e, r, v):
|
||||
ss, sm, se, sbc = s._mpf_
|
||||
es, em, ee, ebc = e._mpf_
|
||||
rs, rm, re, rbc = r._mpf_
|
||||
vs, vm, ve, vbc = v._mpf_
|
||||
|
||||
if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)):
|
||||
return r, v
|
||||
|
||||
me = min(se, ee, re, ve)
|
||||
|
||||
# transform mantissa parts to the same exponent base
|
||||
sm_e = sm << (se - me)
|
||||
em_e = em << (ee - me)
|
||||
rm_e = rm << (re - me)
|
||||
vm_e = vm << (ve - me)
|
||||
|
||||
# find matching higher and non-matching lower bits of e and r
|
||||
sm_b = bin(sm_e)[2:] if sm_e else ''
|
||||
em_b = bin(em_e)[2:] if em_e else ''
|
||||
rm_b = bin(rm_e)[2:] if rm_e else ''
|
||||
vm_b = bin(vm_e)[2:] if vm_e else ''
|
||||
|
||||
m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b))
|
||||
em_b = '0' * (m - len(em_b)) + em_b
|
||||
rm_b = '0' * (m - len(rm_b)) + rm_b
|
||||
|
||||
c1 = 0
|
||||
for b0, b1 in zip(em_b, rm_b):
|
||||
if b0 != b1:
|
||||
break
|
||||
c1 += 1
|
||||
c0 = m - c1
|
||||
|
||||
# truncate r and v mantissa
|
||||
rm_m = rm_e >> c0
|
||||
vm_m = vm_e >> c0
|
||||
|
||||
# normalized r and v
|
||||
nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero)
|
||||
nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero)
|
||||
|
||||
return nr, nv
|
||||
|
||||
ctx = exact.context
|
||||
scale = abs(exact)
|
||||
if isinstance(exact, ctx.mpc):
|
||||
rr, rv = worker(ctx, scale, exact.real, reference.real, value.real)
|
||||
ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag)
|
||||
return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_))
|
||||
elif isinstance(exact, ctx.mpf):
|
||||
return worker(ctx, scale, exact, reference, value)
|
||||
else:
|
||||
assert 0 # unreachable
|
||||
|
@ -3365,87 +3365,133 @@ class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind)
|
||||
for name, dtype, kind in itertools.product(
|
||||
[ 'arccos', 'arccosh', 'arcsin', 'arcsinh',
|
||||
'arctan', 'arctanh', 'conjugate', 'cos',
|
||||
'cosh', 'exp', 'exp2', 'expm1', 'log',
|
||||
'log10', 'log1p', 'sin', 'sinh', 'sqrt',
|
||||
'square', 'tan', 'tanh', 'sinc', 'positive',
|
||||
'negative', 'absolute', 'sign'],
|
||||
dict(testcase_name=f"_{dtype.__name__}", dtype=dtype)
|
||||
for dtype in jtu.dtypes.supported([np.float32, np.float64, np.complex64, np.complex128]))
|
||||
def testMPMathUtils(self, dtype):
|
||||
try:
|
||||
import mpmath
|
||||
except ImportError as msg:
|
||||
self.skipTest(f'could not import mpmath: {msg}')
|
||||
|
||||
prec = {np.float32: 24, np.float64: 53, np.complex64: 24, np.complex128: 53}[dtype]
|
||||
is_complex = dtype().dtype.kind == 'c'
|
||||
|
||||
def func(x):
|
||||
assert isinstance(x, mpmath.ctx_mp.mpnumeric)
|
||||
assert x.context.prec == prec
|
||||
assert isinstance(x, x.context.mpc if is_complex else x.context.mpf)
|
||||
return x
|
||||
|
||||
ufunc = jtu.vectorize_with_mpmath(func, mpmath=mpmath)
|
||||
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||
if is_complex:
|
||||
arr = jtu.complex_plane_sample(dtype=dtype, size_re=11)
|
||||
else:
|
||||
cdtype = getattr(np, ufunc.map_float_to_complex[dtype.__name__])
|
||||
arr = jtu.complex_plane_sample(dtype=cdtype, size_re=11, size_im=0)[1:2].real
|
||||
|
||||
arr2 = ufunc.mptonp(ufunc.nptomp(arr))
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||
self.assertAllClose(arr, arr2, atol=0, rtol=0)
|
||||
|
||||
arr3 = ufunc(arr)
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||
self.assertAllClose(arr, arr3, atol=0, rtol=0)
|
||||
|
||||
if is_complex:
|
||||
# tests scale in normalize
|
||||
v = dtype(1.1071487177940644+1.1102230246251565e-16j)
|
||||
r = dtype(1.1071487177940644+0j)
|
||||
mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1)
|
||||
nr, nv = mnp.normalize(r, r, v)
|
||||
self.assertAllClose(nr, nv)
|
||||
|
||||
_functions_on_complex_plane = [
|
||||
'arccos', 'arccosh', 'arcsin', 'arcsinh',
|
||||
'arctan', 'arctanh', 'conjugate', 'cos',
|
||||
'cosh', 'exp', 'exp2', 'expm1', 'log',
|
||||
'log10', 'log1p', 'sin', 'sinh', 'sqrt',
|
||||
'square', 'tan', 'tanh', 'sinc', 'positive',
|
||||
'negative', 'absolute', 'sign'
|
||||
]
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
|
||||
for name, dtype in itertools.product(
|
||||
_functions_on_complex_plane,
|
||||
jtu.dtypes.supported([np.complex64, np.complex128]),
|
||||
['success', 'failure'],
|
||||
))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testOnComplexPlane(self, name, dtype, kind):
|
||||
all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero']
|
||||
def testSuccessOnComplexPlane(self, name, dtype):
|
||||
self._testOnComplexPlaneWorker(name, dtype, 'success')
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
|
||||
for name, dtype in itertools.product(
|
||||
_functions_on_complex_plane,
|
||||
jtu.dtypes.supported([np.complex64, np.complex128]),
|
||||
))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testFailureOnComplexPlane(self, name, dtype):
|
||||
self._testOnComplexPlaneWorker(name, dtype, 'failure')
|
||||
|
||||
def _testOnComplexPlaneWorker(self, name, dtype, kind):
|
||||
try:
|
||||
import mpmath
|
||||
except ImportError as msg:
|
||||
self.skipTest(f'could not import mpmath: {msg}')
|
||||
|
||||
is_cpu = jtu.test_device_matches(["cpu"])
|
||||
machine = platform.machine()
|
||||
# TODO: remove is_arm_cpu as previously arm cpu related failures
|
||||
# were due to numpy issues. Confirm?
|
||||
is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm')
|
||||
is_cuda = jtu.test_device_matches(["cuda"])
|
||||
|
||||
# TODO(pearu): eliminate all items in the following lists:
|
||||
# TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests
|
||||
regions_with_inaccuracies = dict(
|
||||
absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [],
|
||||
exp = (['pos', 'pinfj', 'pinf', 'ninfj', 'ninf']
|
||||
+ (['q1', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
|
||||
exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])],
|
||||
log = ['q1', 'q2', 'q3', 'q4'],
|
||||
log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'],
|
||||
log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'],
|
||||
sinh = (['pos', 'neg', 'ninf', 'pinf']
|
||||
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
|
||||
cosh = (['pos', 'neg', 'ninf', 'pinf']
|
||||
+ (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])),
|
||||
tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
|
||||
square = (['pinf']
|
||||
+ (['ninfj', 'pinfj'] if is_arm_cpu else [])
|
||||
+ (['ninf'] if not is_arm_cpu else [])
|
||||
+ (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
|
||||
+ (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])),
|
||||
sinc = ['q1', 'q2', 'q3', 'q4'],
|
||||
arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
sin = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
|
||||
cos = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [],
|
||||
expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [],
|
||||
)
|
||||
size_re = 11
|
||||
size_im = 11
|
||||
atol = None
|
||||
|
||||
if jtu.numpy_version() < (2, 0, 0):
|
||||
regions_with_inaccuracies['sign'] = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj']
|
||||
mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1)
|
||||
mnp2 = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=1)
|
||||
|
||||
ref_op = getattr(mnp, name)
|
||||
ref2_op = getattr(mnp2, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
|
||||
if name == 'square':
|
||||
# numpy square is incorrect on inputs with large absolute value
|
||||
tiny = np.finfo(dtype).tiny
|
||||
|
||||
def square(x):
|
||||
re = (x.real - x.imag) * (x.real + x.imag)
|
||||
im = x.real * x.imag * 2
|
||||
if is_cuda:
|
||||
# apply FTZ
|
||||
if np.isfinite(re) and abs(re) < tiny:
|
||||
re *= 0
|
||||
if np.isfinite(im) and abs(im) < tiny:
|
||||
im *= 0
|
||||
return np.array(complex(re, im), dtype=dtype)
|
||||
|
||||
np_op = np.vectorize(square)
|
||||
else:
|
||||
np_op = getattr(np, name)
|
||||
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||
args = (jtu.complex_plane_sample(dtype=dtype, size_re=11),)
|
||||
args = (jtu.complex_plane_sample(dtype=dtype, size_re=size_re, size_im=size_im),)
|
||||
result = np.asarray(jnp_op(*args))
|
||||
expected = np_op(*args)
|
||||
expected = ref_op(*args)
|
||||
expected2 = ref2_op(*args)
|
||||
|
||||
s0, s1 = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2
|
||||
normalized_expected, normalized_result = mnp2.normalize(expected2, expected, result)
|
||||
|
||||
# When comparing the results with expected, we'll divide the
|
||||
# complex plane grid into smaller regions and perform the
|
||||
# closeness tests on each region separately. The reason for this
|
||||
# is that the inaccuracy or incorrectness issues with a particular
|
||||
# function exists typically in specific regions while in other
|
||||
# regions the function is accurate. So, such a division of the
|
||||
# complex plane helps to identify the problematic regions as well
|
||||
# as to fix the inaccuracy or incorrectness issues.
|
||||
#
|
||||
# Regions in complex plane:
|
||||
#
|
||||
# ( pinfj )
|
||||
# ( q2 ) (posj) ( q1 )
|
||||
# (ninf) ( neg ) (zero) ( pos ) (pinf)
|
||||
# ( q3 ) (negj) ( q4 )
|
||||
# ( ninfj )
|
||||
#
|
||||
# In addition, the 1/3 middle parts of regions q1, q2, q3, q4,
|
||||
# neg, pos are tested separately as these don't contain extremely
|
||||
# small or extremelly large values and functions on these regions
|
||||
# ought not to possess any incorrectness issues.
|
||||
|
||||
s0, s1 = size_re, size_im
|
||||
s03, s13 = s0 // 3, s1 // 3
|
||||
s_dict = dict(
|
||||
q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)),
|
||||
q2=(slice(s0 + 2, -1), slice(1, s1 + 1)),
|
||||
@ -3462,38 +3508,167 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)),
|
||||
)
|
||||
|
||||
for region in all_regions:
|
||||
if is_arm_cpu:
|
||||
if (
|
||||
(
|
||||
region in ['q1', 'q2', 'q3', 'q4']
|
||||
and name in ['cos', 'cosh', 'sin', 'sinh', 'exp', 'expm1']
|
||||
)
|
||||
or (region in ['pinfj', 'ninfj'] and name in ['sin', 'cos'])
|
||||
or (region == 'pinf' and name in ['expm1'])
|
||||
):
|
||||
continue
|
||||
s = s_dict[region]
|
||||
inds = np.where(result[s] != expected[s])
|
||||
if inds[0].size > 0:
|
||||
mismatches = []
|
||||
for ind in zip(*inds):
|
||||
x, r, e = args[0][s][ind], str(result[s][ind]), str(expected[s][ind])
|
||||
if r == e:
|
||||
# skip equal nan-s
|
||||
continue
|
||||
mismatches.append(f'jax.numpy.{name}{x} -> {r}, expected {e}')
|
||||
mismatches = "\n".join(mismatches)
|
||||
else:
|
||||
mismatches = ''
|
||||
if kind == 'success' and region not in regions_with_inaccuracies.get(name, []):
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||
self.assertAllClose(result[s], expected[s], err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}")
|
||||
if kind == 'failure' and region in regions_with_inaccuracies.get(name, []):
|
||||
with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||
self.assertAllClose(result[s], expected[s]) # on success, update regions_with_inaccuracies
|
||||
if s03 and s13:
|
||||
s_dict.update(
|
||||
mq1 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
|
||||
mq2 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
|
||||
mq3 = (slice(2 + s03, 2 + 2 * s03), slice(2 + s13, 2 + 2 * s13)),
|
||||
mq4 = (slice(2 + s03, 2 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
|
||||
mneg=(s0 + 1, slice(2 + s13, 2 + 2 * s13)),
|
||||
mpos=(s0 + 1, slice(s1 + 3 + s13, s1 + 3 + 2 * s13)),
|
||||
mnegj=(slice(2 + s03, 2 + 2 * s03), s1 + 1),
|
||||
mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1),
|
||||
)
|
||||
|
||||
# Start with an assumption that all regions are problematic for a
|
||||
# particular function:
|
||||
regions_with_inaccuracies = list(s_dict)
|
||||
|
||||
# Next, we'll remove non-problematic regions from the
|
||||
# regions_with_inaccuracies list by explicitly keeping problematic
|
||||
# regions:
|
||||
def regions_with_inaccuracies_keep(*to_keep):
|
||||
for item in regions_with_inaccuracies[:]:
|
||||
if item not in to_keep:
|
||||
regions_with_inaccuracies.remove(item)
|
||||
|
||||
if name == 'absolute':
|
||||
if is_cuda and dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
|
||||
else:
|
||||
regions_with_inaccuracies.clear()
|
||||
|
||||
elif name == 'sign':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
|
||||
|
||||
elif name == 'square':
|
||||
if is_cuda:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||
if is_cpu:
|
||||
regions_with_inaccuracies_keep('ninf', 'pinf')
|
||||
|
||||
elif name == 'log':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||
|
||||
elif name == 'log10':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero')
|
||||
|
||||
elif name == 'log1p':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||
|
||||
elif name == 'exp':
|
||||
regions_with_inaccuracies_keep('pos', 'pinf', 'mpos')
|
||||
|
||||
elif name == 'exp2':
|
||||
if dtype == np.complex64:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos', 'mnegj', 'mposj')
|
||||
if dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mpos')
|
||||
|
||||
elif name == 'expm1':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
|
||||
|
||||
elif name == 'sinc':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
|
||||
|
||||
elif name == 'tan':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
|
||||
|
||||
elif name == 'sinh':
|
||||
if is_cuda:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
|
||||
if is_cpu:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
|
||||
|
||||
elif name == 'cosh':
|
||||
regions_with_inaccuracies_keep('neg', 'pos', 'ninf', 'pinf', 'mneg', 'mpos')
|
||||
|
||||
elif name == 'tanh':
|
||||
regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
|
||||
|
||||
elif name == 'arccos':
|
||||
if dtype == np.complex64:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
|
||||
if dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
|
||||
|
||||
elif name == 'arccosh':
|
||||
if dtype == np.complex64:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
|
||||
if dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj')
|
||||
|
||||
elif name == 'arcsin':
|
||||
if dtype == np.complex64:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
|
||||
if dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
|
||||
|
||||
elif name == 'arcsinh':
|
||||
if dtype == np.complex64:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
|
||||
if dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mnegj')
|
||||
|
||||
elif name == 'arctan':
|
||||
if dtype == np.complex64:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
|
||||
if dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||
|
||||
elif name == 'arctanh':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
|
||||
|
||||
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt'}:
|
||||
regions_with_inaccuracies.clear()
|
||||
else:
|
||||
assert 0 # unreachable
|
||||
|
||||
# Finally, perform the closeness tests per region:
|
||||
unexpected_success_regions = []
|
||||
for region_name, region_slice in s_dict.items():
|
||||
region = args[0][region_slice]
|
||||
inexact_indices = np.where(normalized_result[region_slice] != normalized_expected[region_slice])
|
||||
|
||||
if inexact_indices[0].size == 0:
|
||||
inexact_samples = ''
|
||||
else:
|
||||
inexact_samples = []
|
||||
for ind in zip(*inexact_indices):
|
||||
x = region[ind]
|
||||
y1, y2 = result[region_slice][ind], expected[region_slice][ind]
|
||||
ny1, ny2 = normalized_result[region_slice][ind], normalized_expected[region_slice][ind]
|
||||
if str(y1) == str(y2): # skip equal nan-s
|
||||
continue
|
||||
max_abs_diff = abs(ny1 - ny2).max() if np.isfinite(y1) and np.isfinite(y1) else np.inf
|
||||
inexact_samples.append((max_abs_diff, f'jax.numpy.{name}({x}) -> {y1} [{ny1}], expected {y2} [{ny2}]'))
|
||||
inexact_samples = "\n".join([msg for _, msg in sorted(inexact_samples)])
|
||||
|
||||
if kind == 'success' and region_name not in regions_with_inaccuracies:
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||
self.assertAllClose(
|
||||
normalized_result[region_slice], normalized_expected[region_slice], atol=atol,
|
||||
err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{inexact_samples}")
|
||||
|
||||
if kind == 'failure' and region_name in regions_with_inaccuracies:
|
||||
try:
|
||||
with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||
self.assertAllClose(normalized_result[region_slice], normalized_expected[region_slice])
|
||||
except AssertionError as msg:
|
||||
if str(msg).startswith('AssertionError not raised'):
|
||||
unexpected_success_regions.append(region_name)
|
||||
else:
|
||||
raise # something else is wrong..
|
||||
|
||||
if kind == 'success' and regions_with_inaccuracies:
|
||||
reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies)
|
||||
raise unittest.SkipTest(reason)
|
||||
|
||||
if kind == 'failure':
|
||||
self.assertEqual(unexpected_success_regions, []) # regions_with_inaccuracies requires an update!
|
||||
if not regions_with_inaccuracies:
|
||||
raise unittest.SkipTest("no problematic regions")
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user