From 6d8b3e4cff97d966e56670e70957334885439b76 Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Thu, 15 Feb 2024 13:29:35 +0200
Subject: [PATCH] Fix complex sin and cos on inputs with small absolute value or large pure imaginary part

---
NOTE(review): text in this region (between "---" and the first "diff --git")
is ignored when the patch is applied.
- This copy was rebuilt from a whitespace-collapsed extraction. Indentation,
  blank context lines and the column padding of the docstring example are
  inferred (hunk line counts were checked against the "@@" headers); confirm
  against the upstream commit before applying.
- The author e-mail address was missing from the extraction and is not
  restored here.
- The FileCheck "CHECK-SAME" element types were stripped by the extraction
  ("tensor>" / "tensor"); they are restored as tensor<complex<f32>> ->
  tensor<f32> and tensor<complex<f64>> -> tensor<f64>. TODO confirm.
- _sin_complex/_cos_complex select _const(a, 0) when real(x) compares equal
  to zero; presumably this yields +0 even for real(x) == -0.0, so the sign of
  a zero component may be dropped. Verify whether that matters to callers.

 jax/_src/lax/lax.py                 |  42 ++++++++++-
 jax/_src/test_util.py               |  60 +++++++++++++++
 tests/filecheck/shapes.filecheck.py |   6 +-
 tests/lax_test.py                   | 113 ++++++++++++++++++++++++++++
 4 files changed, 217 insertions(+), 4 deletions(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 6ab420f78..a168d4fbd 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1837,13 +1837,51 @@ def logistic_impl(x):
 mlir.register_lowering(logistic_p,
                        mlir.lower_fun(logistic_impl, multiple_results=False))
 
+def _sin_complex(x):
+  # use expm1 instead of exp to avoid cancellation when abs(x) is small
+  # relies on the quality of real-valued expm1, sin, cos
+  # sin(x) = complex(sin(real(x)) * cosh(imag(x)), cos(real(x)) * sinh(imag(x)))
+  # 2 * sinh(x) = exp(x) - 1 - (exp(-x) - 1) = expm1(x) - expm1(-x)
+  # 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
+  a, b = real(x), imag(x)
+  a_is_zero = eq(a, _const(a, 0))
+  sn, cs = sin(a), cos(a)
+  e1m, e2m = expm1(b), expm1(-b)
+  snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
+  re, im = sn * csh, cs * snh
+  # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
+  return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
+
+def _sin_lowering(ctx, x):
+  if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
+    sine = mlir.lower_fun(_sin_complex, multiple_results=False)
+    return sine(ctx, x)
+  return _nary_lower_hlo(hlo.sine, ctx, x)
+
 sin_p = standard_unop(_float | _complex, 'sin')
 ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
-mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.sine))
+mlir.register_lowering(sin_p, _sin_lowering)
+
+def _cos_complex(x):
+  # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
+  # see also _sin_complex
+  a, b = real(x), imag(x)
+  a_is_zero = eq(a, _const(a, 0))
+  sn, cs = sin(a), cos(a)
+  e1m, e2m = expm1(b), expm1(-b)
+  snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
+  re, im = cs * csh, -sn * snh
+  return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
+
+def _cos_lowering(ctx, x):
+  if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
+    cosine = mlir.lower_fun(_cos_complex, multiple_results=False)
+    return cosine(ctx, x)
+  return _nary_lower_hlo(hlo.cosine, ctx, x)
 
 cos_p = standard_unop(_float | _complex, 'cos')
 ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
-mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.cosine))
+mlir.register_lowering(cos_p, _cos_lowering)
 
 @_upcast_fp16_for_computation
 def _tan_impl(x):
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index dd84d6318..7dc57b53a 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1413,3 +1413,63 @@ def numpy_vecdot(x, y, axis):
   y = np.moveaxis(y, axis, -1)
   x, y = np.broadcast_arrays(x, y)
   return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
+
+
+def complex_plane_sample(dtype, size_re=10, size_im=None):
+  """Return a 2-D array of complex numbers that covers the complex plane
+  with a grid of samples.
+
+  The size of the grid is (3 + 2 * size_im) x (3 + 2 * size_re)
+  that includes infinity points, extreme finite points, and the
+  specified number of points from real and imaginary axis.
+
+  For example:
+
+    >>> print(complex_plane_sample(np.complex64, 0, 3))
+    [[-inf          -infj   0.          -infj  inf          -infj]
+     [-inf-3.4028235e+38j   0.-3.4028235e+38j  inf-3.4028235e+38j]
+     [-inf-2.0000052e+00j   0.-2.0000052e+00j  inf-2.0000052e+00j]
+     [-inf-1.1754944e-38j   0.-1.1754944e-38j  inf-1.1754944e-38j]
+     [-inf+0.0000000e+00j   0.+0.0000000e+00j  inf+0.0000000e+00j]
+     [-inf+1.1754944e-38j   0.+1.1754944e-38j  inf+1.1754944e-38j]
+     [-inf+2.0000052e+00j   0.+2.0000052e+00j  inf+2.0000052e+00j]
+     [-inf+3.4028235e+38j   0.+3.4028235e+38j  inf+3.4028235e+38j]
+     [-inf          +infj   0.          +infj  inf          +infj]]
+
+  """
+  if size_im is None:
+    size_im = size_re
+  finfo = np.finfo(dtype)
+
+  def make_axis_points(size):
+    logmin = np.log10(abs(finfo.min))
+    logtiny = np.log10(finfo.tiny)
+    logmax = np.log10(finfo.max)
+    axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
+
+    with warnings.catch_warnings():
+      # Silence RuntimeWarning: overflow encountered in cast
+      warnings.simplefilter("ignore")
+      axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
+      axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype)
+
+    if size > 1:
+      axis_points[1] = finfo.min
+      axis_points[-2] = finfo.max
+    if size > 0:
+      axis_points[size] = -finfo.tiny
+      axis_points[-size - 1] = finfo.tiny
+    axis_points[0] = -np.inf
+    axis_points[-1] = np.inf
+    return axis_points
+
+  real_axis_points = make_axis_points(size_re)
+  imag_axis_points = make_axis_points(size_im)
+
+  real_part = real_axis_points.reshape((-1, 3 + 2 * size_re)).repeat(3 + 2 * size_im, 0).astype(dtype)
+
+  imag_part = imag_axis_points.repeat(2).view(dtype)
+  imag_part.real[:] = 0
+  imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)
+
+  return real_part + imag_part
diff --git a/tests/filecheck/shapes.filecheck.py b/tests/filecheck/shapes.filecheck.py
index 5cd8feb86..f834e8c78 100644
--- a/tests/filecheck/shapes.filecheck.py
+++ b/tests/filecheck/shapes.filecheck.py
@@ -96,12 +96,14 @@ def main(_):
 
   # CHECK-LABEL: TEST: cos complex64[]
   # CHECK: hlo.cosine
-  # CHECK-SAME: tensor<complex<f32>>
+  # TODO: when the accuracy of lax.cos is fixed upstream, undo relevant parts of jax PR 19823
+  # CHECK-SAME: tensor<f32>
   print_ir(np.complex64(0))(lax.cos)
 
   # CHECK-LABEL: TEST: cos complex128[]
   # CHECK: hlo.cosine
-  # CHECK-SAME: tensor<complex<f64>>
+  # TODO: when the accuracy of lax.cos is fixed upstream, undo relevant parts of jax PR 19823
+  # CHECK-SAME: tensor<f64>
   print_ir(np.complex128(0))(lax.cos)
 
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index f7446d010..ecd9ec4a9 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -46,6 +46,7 @@ from jax._src.interpreters import pxla
 from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.util import NumpyComplexWarning
 
 config.parse_flags_with_absl()
@@ -3335,5 +3336,117 @@ class CustomElementTypesTest(jtu.JaxTestCase):
 
   # TODO(frostig,mattjj): more polymorphic primitives tests
 
+
+class FunctionAccuracyTest(jtu.JaxTestCase):
+
+  @parameterized.named_parameters(
+    dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind)
+    for name, dtype, kind in itertools.product(
+        [ 'arccos', 'arccosh', 'arcsin', 'arcsinh',
+          'arctan', 'arctanh', 'conjugate', 'cos',
+          'cosh', 'exp', 'exp2', 'expm1', 'log',
+          'log10', 'log1p', 'sin', 'sinh', 'sqrt',
+          'square', 'tan', 'tanh', 'sinc', 'positive',
+          'negative', 'absolute', 'sign'],
+        jtu.dtypes.supported([np.complex64, np.complex128]),
+        ['success', 'failure'],
+    ))
+  @jtu.skip_on_devices("tpu")
+  def testOnComplexPlane(self, name, dtype, kind):
+    all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero']
+    is_cpu = jtu.test_device_matches(["cpu"])
+    is_cuda = jtu.test_device_matches(["cuda"])
+
+    # TODO(pearu): eliminate all items in the following lists:
+    # TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests
+    regions_with_inaccuracies = dict(
+      absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [],
+      exp = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf'],
+      exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])],
+      log = ['q1', 'q2', 'q3', 'q4'],
+      log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'],
+      log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'],
+      sinh = ['pos', 'neg', 'ninf', 'pinf'],
+      cosh = ['pos', 'neg', 'ninf', 'pinf'],
+      tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
+      square = ((['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
+                + ['ninf', 'pinf']
+                + (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])),
+      sinc = ['q1', 'q2', 'q3', 'q4'],
+      sign = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
+      arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
+      arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
+      arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
+      arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
+      arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
+      arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
+    )
+
+    jnp_op = getattr(jnp, name)
+
+    if name == 'square':
+      # numpy square is incorrect on inputs with large absolute value
+      tiny = np.finfo(dtype).tiny
+
+      def square(x):
+        re = (x.real - x.imag) * (x.real + x.imag)
+        im = x.real * x.imag * 2
+        if is_cuda:
+          # apply FTZ
+          if np.isfinite(re) and abs(re) < tiny:
+            re *= 0
+          if np.isfinite(im) and abs(im) < tiny:
+            im *= 0
+        return np.array(complex(re, im), dtype=dtype)
+
+      np_op = np.vectorize(square)
+    else:
+      np_op = getattr(np, name)
+
+    with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
+      args = (jtu.complex_plane_sample(dtype=dtype, size_re=11),)
+      result = np.asarray(jnp_op(*args))
+      expected = np_op(*args)
+
+    s0, s1 = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2
+    s_dict = dict(
+      q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)),
+      q2=(slice(s0 + 2, -1), slice(1, s1 + 1)),
+      q3=(slice(1, s0 + 1), slice(1, s1 + 1)),
+      q4=(slice(1, s0 + 1), slice(s1 + 2, -1)),
+      neg=(s0 + 1, slice(1, s1 + 1)),
+      pos=(s0 + 1, slice(s1 + 2, -1)),
+      negj=(slice(1, s0 + 1), s1 + 1),
+      posj=(slice(s0 + 2, -1), s1 + 1),
+      ninf=(slice(None), 0),
+      pinf=(slice(None), -1),
+      ninfj=(0, slice(None)),
+      pinfj=(-1, slice(None)),
+      zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)),
+    )
+
+    for region in all_regions:
+      s = s_dict[region]
+      inds = np.where(result[s] != expected[s])
+      if inds[0].size > 0:
+        mismatches = []
+        for ind in zip(*inds):
+          x, r, e = args[0][s][ind], str(result[s][ind]), str(expected[s][ind])
+          if r == e:
+            # skip equal nan-s
+            continue
+          mismatches.append(f'jax.numpy.{name}{x} -> {r}, expected {e}')
+        mismatches = "\n".join(mismatches)
+      else:
+        mismatches = ''
+      if kind == 'success' and region not in regions_with_inaccuracies.get(name, []):
+        with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
+          self.assertAllClose(result[s], expected[s], err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}")
+      if kind == 'failure' and region in regions_with_inaccuracies.get(name, []):
+        with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
+          with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
+            self.assertAllClose(result[s], expected[s])  # on success, update regions_with_inaccuracies
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())