diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 0744b2ac3..800bc735d 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -4,6 +4,7 @@ cloudpickle colorama>=0.4.4 flatbuffers hypothesis +mpmath>=1.3 numpy>=1.22 pillow>=9.1.0 portpicker diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 22addcc59..63d9597f0 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1467,11 +1467,11 @@ def complex_plane_sample(dtype, size_re=10, size_im=None): >>> print(complex_plane_sample(np.complex64, 0, 3)) [[-inf -infj 0. -infj inf -infj] [-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j] - [-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j] + [-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j] [-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j] [-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j] [-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j] - [-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j] + [-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j] [-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j] [-inf +infj 0. +infj inf +infj]] @@ -1481,16 +1481,18 @@ def complex_plane_sample(dtype, size_re=10, size_im=None): finfo = np.finfo(dtype) def make_axis_points(size): - logmin = np.log10(abs(finfo.min)) - logtiny = np.log10(finfo.tiny) - logmax = np.log10(finfo.max) + prec_dps_ratio = 3.3219280948873626 + logmin = logmax = finfo.maxexp / prec_dps_ratio + logtiny = finfo.minexp / prec_dps_ratio axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype) with warnings.catch_warnings(): # Silence RuntimeWarning: overflow encountered in cast warnings.simplefilter("ignore") - axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype) - axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype) + half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype) + half_line = -half_neg_line[::-1] + axis_points[-size - 1:-1] = half_line + axis_points[1:size + 1] = half_neg_line if size > 1: axis_points[1] = finfo.min @@ -1512,3 +1514,302 @@ def complex_plane_sample(dtype, size_re=10, size_im=None): imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1) return real_part + imag_part + + +class vectorize_with_mpmath(np.vectorize): + """Same as numpy.vectorize but using mpmath backend for function evaluation. + """ + + map_float_to_complex = dict(float16='complex32', float32='complex64', float64='complex128', float128='complex256', longdouble='clongdouble') + map_complex_to_float = {v: k for k, v in map_float_to_complex.items()} + + float_prec = dict( + # float16=11, + float32=24, + float64=53, + # float128=113, + # longdouble=113 + ) + + float_minexp = dict( + float16=-14, + float32=-126, + float64=-1022, + float128=-16382 + ) + + float_maxexp = dict( + float16=16, + float32=128, + float64=1024, + float128=16384, + ) + + def __init__(self, *args, **kwargs): + mpmath = kwargs.pop('mpmath', None) + if mpmath is None: + raise ValueError('vectorize_with_mpmath: no mpmath argument specified') + self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0) + self.extra_prec = kwargs.pop('extra_prec', 0) + self.mpmath = mpmath + self.contexts = dict() + self.contexts_inv = dict() + for fp_format, prec in self.float_prec.items(): + ctx = self.mpmath.mp.clone() + ctx.prec = prec + self.contexts[fp_format] = ctx + self.contexts_inv[ctx] = fp_format + + super().__init__(*args, **kwargs) + + def get_context(self, x): + if isinstance(x, (np.ndarray, np.floating, np.complexfloating)): + fp_format = str(x.dtype) + fp_format = self.map_complex_to_float.get(fp_format, fp_format) + return self.contexts[fp_format] + raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance') + + def nptomp(self, x): + """Convert numpy array/scalar to an array/instance of mpmath number type. + """ + if isinstance(x, np.ndarray): + return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape) + elif isinstance(x, np.floating): + mpmath = self.mpmath + ctx = self.get_context(x) + prec, rounding = ctx._prec_rounding + if np.isposinf(x): + return ctx.make_mpf(mpmath.libmp.finf) + elif np.isneginf(x): + return ctx.make_mpf(mpmath.libmp.fninf) + elif np.isnan(x): + return ctx.make_mpf(mpmath.libmp.fnan) + elif np.isfinite(x): + mantissa, exponent = np.frexp(x) + man = int(np.ldexp(mantissa, prec)) + exp = int(exponent - prec) + r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding)) + assert ctx.isfinite(r), r._mpf_ + return r + elif isinstance(x, np.complexfloating): + re, im = self.nptomp(x.real), self.nptomp(x.imag) + return re.context.make_mpc((re._mpf_, im._mpf_)) + raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type') + + def mptonp(self, x): + """Convert mpmath instance to numpy array/scalar type. + """ + if isinstance(x, np.ndarray) and x.dtype.kind == 'O': + x_flat = x.flatten() + item = x_flat[0] + ctx = item.context + fp_format = self.contexts_inv[ctx] + if isinstance(item, ctx.mpc): + dtype = getattr(np, self.map_float_to_complex[fp_format]) + elif isinstance(item, ctx.mpf): + dtype = getattr(np, fp_format) + else: + dtype = None + if dtype is not None: + return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape) + elif isinstance(x, self.mpmath.ctx_mp.mpnumeric): + ctx = x.context + if isinstance(x, ctx.mpc): + fp_format = self.contexts_inv[ctx] + dtype = getattr(np, self.map_float_to_complex[fp_format]) + r = dtype().reshape(1).view(getattr(np, fp_format)) + r[0] = self.mptonp(x.real) + r[1] = self.mptonp(x.imag) + return r.view(dtype)[0] + elif isinstance(x, ctx.mpf): + fp_format = self.contexts_inv[ctx] + dtype = getattr(np, fp_format) + if ctx.isfinite(x): + sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding) + assert bc >= 0, (sign, man, exp, bc, x._mpf_) + if exp + bc < self.float_minexp[fp_format]: + return -ctx.zero if sign else ctx.zero + if exp + bc > self.float_maxexp[fp_format]: + return ctx.ninf if sign else ctx.inf + man = dtype(-man if sign else man) + r = np.ldexp(man, exp) + assert np.isfinite(r), (x, r, x._mpf_, man) + return r + elif ctx.isnan(x): + return dtype(np.nan) + elif ctx.isinf(x): + return dtype(-np.inf if x._mpf_[0] else np.inf) + raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type') + + def __call__(self, *args, **kwargs): + mp_args = [] + context = None + for a in args: + if isinstance(a, (np.ndarray, np.floating, np.complexfloating)): + mp_args.append(self.nptomp(a)) + if context is None: + context = self.get_context(a) + else: + assert context is self.get_context(a) + else: + mp_args.append(a) + + extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec + with context.extraprec(extra_prec): + result = super().__call__(*mp_args, **kwargs) + + if isinstance(result, tuple): + lst = [] + for r in result: + if ((isinstance(r, np.ndarray) and r.dtype.kind == 'O') + or isinstance(r, self.mpmath.ctx_mp.mpnumeric)): + r = self.mptonp(r) + lst.append(r) + return tuple(lst) + + if ((isinstance(result, np.ndarray) and result.dtype.kind == 'O') + or isinstance(result, self.mpmath.ctx_mp.mpnumeric)): + return self.mptonp(result) + + return result + + +class numpy_with_mpmath: + """Namespace of universal functions on numpy arrays that use mpmath + backend for evaluation and return numpy arrays as outputs. + """ + + _provides = [ + 'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2', + 'log', 'log1p', 'log10', 'log2', + 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', + 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh', + 'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc', + 'normalize', + ] + + _mp_names = dict( + abs='absmin', absolute='absmin', + log='ln', + arcsin='asin', arccos='acos', arctan='atan', + arcsinh='asinh', arccosh='acosh', arctanh='atanh', + ) + + def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0): + self.mpmath = mpmath + + for name in self._provides: + mp_name = self._mp_names.get(name, name) + + if hasattr(self, name): + op = getattr(self, name) + else: + + def op(x, mp_name=mp_name): + return getattr(x.context, mp_name)(x) + + setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec)) + + # The following function methods operate on mpmath number instances. + # The corresponding function names must be listed in + # numpy_with_mpmath._provides list. + + def square(self, x): + return x * x + + def positive(self, x): + return x + + def negative(self, x): + return -x + + def sqrt(self, x): + ctx = x.context + # workaround mpmath bugs: + if isinstance(x, ctx.mpc): + if ctx.isinf(x.real) and ctx.isinf(x.imag): + if x.real > 0: return x + ninf = x.real + inf = -ninf + if x.imag > 0: return ctx.make_mpc((inf._mpf_, inf._mpf_)) + return ctx.make_mpc((inf._mpf_, inf._mpf_)) + elif ctx.isfinite(x.real) and ctx.isinf(x.imag): + if x.imag > 0: + inf = x.imag + return ctx.make_mpc((inf._mpf_, inf._mpf_)) + else: + ninf = x.imag + inf = -ninf + return ctx.make_mpc((inf._mpf_, ninf._mpf_)) + return ctx.sqrt(x) + + def expm1(self, x): + return x.context.expm1(x) + + def log2(self, x): + return x.context.ln(x) / x.context.ln2 + + def log10(self, x): + return x.context.ln(x) / x.context.ln10 + + def exp2(self, x): + return x.context.exp(x * x.context.ln2) + + def normalize(self, exact, reference, value): + """Normalize reference and value using precision defined by the + difference of exact and reference. + """ + def worker(ctx, s, e, r, v): + ss, sm, se, sbc = s._mpf_ + es, em, ee, ebc = e._mpf_ + rs, rm, re, rbc = r._mpf_ + vs, vm, ve, vbc = v._mpf_ + + if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)): + return r, v + + me = min(se, ee, re, ve) + + # transform mantissa parts to the same exponent base + sm_e = sm << (se - me) + em_e = em << (ee - me) + rm_e = rm << (re - me) + vm_e = vm << (ve - me) + + # find matching higher and non-matching lower bits of e and r + sm_b = bin(sm_e)[2:] if sm_e else '' + em_b = bin(em_e)[2:] if em_e else '' + rm_b = bin(rm_e)[2:] if rm_e else '' + vm_b = bin(vm_e)[2:] if vm_e else '' + + m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b)) + em_b = '0' * (m - len(em_b)) + em_b + rm_b = '0' * (m - len(rm_b)) + rm_b + + c1 = 0 + for b0, b1 in zip(em_b, rm_b): + if b0 != b1: + break + c1 += 1 + c0 = m - c1 + + # truncate r and v mantissa + rm_m = rm_e >> c0 + vm_m = vm_e >> c0 + + # normalized r and v + nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero) + nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero) + + return nr, nv + + ctx = exact.context + scale = abs(exact) + if isinstance(exact, ctx.mpc): + rr, rv = worker(ctx, scale, exact.real, reference.real, value.real) + ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag) + return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_)) + elif isinstance(exact, ctx.mpf): + return worker(ctx, scale, exact, reference, value) + else: + assert 0 # unreachable diff --git a/tests/lax_test.py b/tests/lax_test.py index a3980e056..4a5c0be2e 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3365,87 +3365,133 @@ class CustomElementTypesTest(jtu.JaxTestCase): class FunctionAccuracyTest(jtu.JaxTestCase): @parameterized.named_parameters( - dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind) - for name, dtype, kind in itertools.product( - [ 'arccos', 'arccosh', 'arcsin', 'arcsinh', - 'arctan', 'arctanh', 'conjugate', 'cos', - 'cosh', 'exp', 'exp2', 'expm1', 'log', - 'log10', 'log1p', 'sin', 'sinh', 'sqrt', - 'square', 'tan', 'tanh', 'sinc', 'positive', - 'negative', 'absolute', 'sign'], + dict(testcase_name=f"_{dtype.__name__}", dtype=dtype) + for dtype in jtu.dtypes.supported([np.float32, np.float64, np.complex64, np.complex128])) + def testMPMathUtils(self, dtype): + try: + import mpmath + except ImportError as msg: + self.skipTest(f'could not import mpmath: {msg}') + + prec = {np.float32: 24, np.float64: 53, np.complex64: 24, np.complex128: 53}[dtype] + is_complex = dtype().dtype.kind == 'c' + + def func(x): + assert isinstance(x, mpmath.ctx_mp.mpnumeric) + assert x.context.prec == prec + assert isinstance(x, x.context.mpc if is_complex else x.context.mpf) + return x + + ufunc = jtu.vectorize_with_mpmath(func, mpmath=mpmath) + + with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): + if is_complex: + arr = jtu.complex_plane_sample(dtype=dtype, size_re=11) + else: + cdtype = getattr(np, ufunc.map_float_to_complex[dtype.__name__]) + arr = jtu.complex_plane_sample(dtype=cdtype, size_re=11, size_im=0)[1:2].real + + arr2 = ufunc.mptonp(ufunc.nptomp(arr)) + with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): + self.assertAllClose(arr, arr2, atol=0, rtol=0) + + arr3 = ufunc(arr) + with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): + self.assertAllClose(arr, arr3, atol=0, rtol=0) + + if is_complex: + # tests scale in normalize + v = dtype(1.1071487177940644+1.1102230246251565e-16j) + r = dtype(1.1071487177940644+0j) + mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1) + nr, nv = mnp.normalize(r, r, v) + self.assertAllClose(nr, nv) + + _functions_on_complex_plane = [ + 'arccos', 'arccosh', 'arcsin', 'arcsinh', + 'arctan', 'arctanh', 'conjugate', 'cos', + 'cosh', 'exp', 'exp2', 'expm1', 'log', + 'log10', 'log1p', 'sin', 'sinh', 'sqrt', + 'square', 'tan', 'tanh', 'sinc', 'positive', + 'negative', 'absolute', 'sign' + ] + + @parameterized.named_parameters( + dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype) + for name, dtype in itertools.product( + _functions_on_complex_plane, jtu.dtypes.supported([np.complex64, np.complex128]), - ['success', 'failure'], )) @jtu.skip_on_devices("tpu") - def testOnComplexPlane(self, name, dtype, kind): - all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero'] + def testSuccessOnComplexPlane(self, name, dtype): + self._testOnComplexPlaneWorker(name, dtype, 'success') + + @parameterized.named_parameters( + dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype) + for name, dtype in itertools.product( + _functions_on_complex_plane, + jtu.dtypes.supported([np.complex64, np.complex128]), + )) + @jtu.skip_on_devices("tpu") + def testFailureOnComplexPlane(self, name, dtype): + self._testOnComplexPlaneWorker(name, dtype, 'failure') + + def _testOnComplexPlaneWorker(self, name, dtype, kind): + try: + import mpmath + except ImportError as msg: + self.skipTest(f'could not import mpmath: {msg}') + is_cpu = jtu.test_device_matches(["cpu"]) machine = platform.machine() + # TODO: remove is_arm_cpu as previously arm cpu related failures + # were due to numpy issues. Confirm? is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm') is_cuda = jtu.test_device_matches(["cuda"]) - # TODO(pearu): eliminate all items in the following lists: - # TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests - regions_with_inaccuracies = dict( - absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [], - exp = (['pos', 'pinfj', 'pinf', 'ninfj', 'ninf'] - + (['q1', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])), - exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])], - log = ['q1', 'q2', 'q3', 'q4'], - log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'], - log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'], - sinh = (['pos', 'neg', 'ninf', 'pinf'] - + (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])), - cosh = (['pos', 'neg', 'ninf', 'pinf'] - + (['q1', 'q2', 'q3', 'q4'] if is_arm_cpu and dtype != np.complex128 else [])), - tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'], - square = (['pinf'] - + (['ninfj', 'pinfj'] if is_arm_cpu else []) - + (['ninf'] if not is_arm_cpu else []) - + (['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else []) - + (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])), - sinc = ['q1', 'q2', 'q3', 'q4'], - arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'], - sin = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [], - cos = ['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_arm_cpu and dtype != np.complex128 else [], - expm1 = ['q1', 'q4', 'pinf'] if is_arm_cpu and dtype != np.complex128 else [], - ) + size_re = 11 + size_im = 11 + atol = None - if jtu.numpy_version() < (2, 0, 0): - regions_with_inaccuracies['sign'] = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'] + mnp = jtu.numpy_with_mpmath(mpmath, extra_prec=1) + mnp2 = jtu.numpy_with_mpmath(mpmath, extra_prec_multiplier=1) + ref_op = getattr(mnp, name) + ref2_op = getattr(mnp2, name) jnp_op = getattr(jnp, name) - if name == 'square': - # numpy square is incorrect on inputs with large absolute value - tiny = np.finfo(dtype).tiny - - def square(x): - re = (x.real - x.imag) * (x.real + x.imag) - im = x.real * x.imag * 2 - if is_cuda: - # apply FTZ - if np.isfinite(re) and abs(re) < tiny: - re *= 0 - if np.isfinite(im) and abs(im) < tiny: - im *= 0 - return np.array(complex(re, im), dtype=dtype) - - np_op = np.vectorize(square) - else: - np_op = getattr(np, name) - with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"): - args = (jtu.complex_plane_sample(dtype=dtype, size_re=11),) + args = (jtu.complex_plane_sample(dtype=dtype, size_re=size_re, size_im=size_im),) result = np.asarray(jnp_op(*args)) - expected = np_op(*args) + expected = ref_op(*args) + expected2 = ref2_op(*args) - s0, s1 = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2 + normalized_expected, normalized_result = mnp2.normalize(expected2, expected, result) + + # When comparing the results with expected, we'll divide the + # complex plane grid into smaller regions and perform the + # closeness tests on each region separately. The reason for this + # is that the inaccuracy or incorrectness issues with a particular + # function exists typically in specific regions while in other + # regions the function is accurate. So, such a division of the + # complex plane helps to identify the problematic regions as well + # as to fix the inaccuracy or incorrectness issues. + # + # Regions in complex plane: + # + # ( pinfj ) + # ( q2 ) (posj) ( q1 ) + # (ninf) ( neg ) (zero) ( pos ) (pinf) + # ( q3 ) (negj) ( q4 ) + # ( ninfj ) + # + # In addition, the 1/3 middle parts of regions q1, q2, q3, q4, + # neg, pos are tested separately as these don't contain extremely + # small or extremelly large values and functions on these regions + # ought not to possess any incorrectness issues. + + s0, s1 = size_re, size_im + s03, s13 = s0 // 3, s1 // 3 s_dict = dict( q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)), q2=(slice(s0 + 2, -1), slice(1, s1 + 1)), @@ -3462,38 +3508,167 @@ class FunctionAccuracyTest(jtu.JaxTestCase): zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)), ) - for region in all_regions: - if is_arm_cpu: - if ( - ( - region in ['q1', 'q2', 'q3', 'q4'] - and name in ['cos', 'cosh', 'sin', 'sinh', 'exp', 'expm1'] - ) - or (region in ['pinfj', 'ninfj'] and name in ['sin', 'cos']) - or (region == 'pinf' and name in ['expm1']) - ): - continue - s = s_dict[region] - inds = np.where(result[s] != expected[s]) - if inds[0].size > 0: - mismatches = [] - for ind in zip(*inds): - x, r, e = args[0][s][ind], str(result[s][ind]), str(expected[s][ind]) - if r == e: - # skip equal nan-s - continue - mismatches.append(f'jax.numpy.{name}{x} -> {r}, expected {e}') - mismatches = "\n".join(mismatches) - else: - mismatches = '' - if kind == 'success' and region not in regions_with_inaccuracies.get(name, []): - with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): - self.assertAllClose(result[s], expected[s], err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}") - if kind == 'failure' and region in regions_with_inaccuracies.get(name, []): - with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"): - with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): - self.assertAllClose(result[s], expected[s]) # on success, update regions_with_inaccuracies + if s03 and s13: + s_dict.update( + mq1 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), + mq2 = (slice(s0 + 3 + s03, s0 + 3 + 2 * s03), slice(2 + s13, 2 + 2 * s13)), + mq3 = (slice(2 + s03, 2 + 2 * s03), slice(2 + s13, 2 + 2 * s13)), + mq4 = (slice(2 + s03, 2 + 2 * s03), slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), + mneg=(s0 + 1, slice(2 + s13, 2 + 2 * s13)), + mpos=(s0 + 1, slice(s1 + 3 + s13, s1 + 3 + 2 * s13)), + mnegj=(slice(2 + s03, 2 + 2 * s03), s1 + 1), + mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1), + ) + # Start with an assumption that all regions are problematic for a + # particular function: + regions_with_inaccuracies = list(s_dict) + + # Next, we'll remove non-problematic regions from the + # regions_with_inaccuracies list by explicitly keeping problematic + # regions: + def regions_with_inaccuracies_keep(*to_keep): + for item in regions_with_inaccuracies[:]: + if item not in to_keep: + regions_with_inaccuracies.remove(item) + + if name == 'absolute': + if is_cuda and dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4') + else: + regions_with_inaccuracies.clear() + + elif name == 'sign': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4') + + elif name == 'square': + if is_cuda: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj') + if is_cpu: + regions_with_inaccuracies_keep('ninf', 'pinf') + + elif name == 'log': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj') + + elif name == 'log10': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero') + + elif name == 'log1p': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj') + + elif name == 'exp': + regions_with_inaccuracies_keep('pos', 'pinf', 'mpos') + + elif name == 'exp2': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos', 'mnegj', 'mposj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mpos') + + elif name == 'expm1': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos') + + elif name == 'sinc': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj') + + elif name == 'tan': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj') + + elif name == 'sinh': + if is_cuda: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos') + if is_cpu: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos') + + elif name == 'cosh': + regions_with_inaccuracies_keep('neg', 'pos', 'ninf', 'pinf', 'mneg', 'mpos') + + elif name == 'tanh': + regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj') + + elif name == 'arccos': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj') + + elif name == 'arccosh': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj') + + elif name == 'arcsin': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj') + + elif name == 'arcsinh': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mnegj') + + elif name == 'arctan': + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj') + + elif name == 'arctanh': + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos') + + elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt'}: + regions_with_inaccuracies.clear() + else: + assert 0 # unreachable + + # Finally, perform the closeness tests per region: + unexpected_success_regions = [] + for region_name, region_slice in s_dict.items(): + region = args[0][region_slice] + inexact_indices = np.where(normalized_result[region_slice] != normalized_expected[region_slice]) + + if inexact_indices[0].size == 0: + inexact_samples = '' + else: + inexact_samples = [] + for ind in zip(*inexact_indices): + x = region[ind] + y1, y2 = result[region_slice][ind], expected[region_slice][ind] + ny1, ny2 = normalized_result[region_slice][ind], normalized_expected[region_slice][ind] + if str(y1) == str(y2): # skip equal nan-s + continue + max_abs_diff = abs(ny1 - ny2).max() if np.isfinite(y1) and np.isfinite(y1) else np.inf + inexact_samples.append((max_abs_diff, f'jax.numpy.{name}({x}) -> {y1} [{ny1}], expected {y2} [{ny2}]')) + inexact_samples = "\n".join([msg for _, msg in sorted(inexact_samples)]) + + if kind == 'success' and region_name not in regions_with_inaccuracies: + with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): + self.assertAllClose( + normalized_result[region_slice], normalized_expected[region_slice], atol=atol, + err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{inexact_samples}") + + if kind == 'failure' and region_name in regions_with_inaccuracies: + try: + with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"): + with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"): + self.assertAllClose(normalized_result[region_slice], normalized_expected[region_slice]) + except AssertionError as msg: + if str(msg).startswith('AssertionError not raised'): + unexpected_success_regions.append(region_name) + else: + raise # something else is wrong.. + + if kind == 'success' and regions_with_inaccuracies: + reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies) + raise unittest.SkipTest(reason) + + if kind == 'failure': + self.assertEqual(unexpected_success_regions, []) # regions_with_inaccuracies requires an update! + if not regions_with_inaccuracies: + raise unittest.SkipTest("no problematic regions") if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())