Fix complex sin and cos on inputs with small absolute value or large pure imaginary part

This commit is contained in:
Pearu Peterson 2024-02-15 13:29:35 +02:00
parent d434ab55d7
commit 6d8b3e4cff
4 changed files with 217 additions and 4 deletions

View File

@ -1837,13 +1837,51 @@ def logistic_impl(x):
# Lower logistic_p by wrapping logistic_impl with mlir.lower_fun (the
# wrapped function produces a single result, hence multiple_results=False).
mlir.register_lowering(logistic_p,
                       mlir.lower_fun(logistic_impl, multiple_results=False))
def _sin_complex(x):
  """Evaluate sin(x) for complex x from real-valued sin, cos and expm1.

  The computation follows the decomposition

    sin(x) = complex(sin(real(x)) * cosh(imag(x)), cos(real(x)) * sinh(imag(x)))

  where the hyperbolic factors are expressed through expm1 instead of exp to
  avoid cancellation when abs(x) is small:

    2 * sinh(t) = exp(t) - 1 - (exp(-t) - 1) = expm1(t) - expm1(-t)
    2 * cosh(t) = exp(t) - 1 + (exp(-t) - 1) + 2 = expm1(t) + expm1(-t) + 2

  The accuracy of the result therefore relies on the quality of the
  real-valued expm1, sin and cos.
  """
  x_re, x_im = real(x), imag(x)
  real_part_is_zero = eq(x_re, _const(x_re, 0))
  sin_re, cos_re = sin(x_re), cos(x_re)
  expm1_pos, expm1_neg = expm1(x_im), expm1(-x_im)
  sinh_im = (expm1_pos - expm1_neg) / 2
  cosh_im = (expm1_pos + expm1_neg + 2) / 2
  out_re = sin_re * cosh_im
  out_im = cos_re * sinh_im
  # When real(x) is zero and imag(x) is so large that expm1 overflows to inf,
  # sin_re * cosh_im is 0 * inf = nan; substitute an exact zero real part for
  # those elements.
  return select(real_part_is_zero,
                complex(_const(x_re, 0), out_im),
                complex(out_re, out_im))
def _sin_lowering(ctx, x):
  """Lowering rule for sin_p.

  Complex inputs are lowered through _sin_complex (wrapped with
  mlir.lower_fun); every other input dtype is lowered directly to hlo.sine.
  """
  operand_dtype = ctx.avals_in[0].dtype
  if not dtypes.issubdtype(operand_dtype, np.complexfloating):
    return _nary_lower_hlo(hlo.sine, ctx, x)
  complex_sine = mlir.lower_fun(_sin_complex, multiple_results=False)
  return complex_sine(ctx, x)
# sin primitive, accepting the dtypes in _float | _complex.
sin_p = standard_unop(_float | _complex, 'sin')
# JVP rule: d/dx sin(x) = cos(x).
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
# Register exactly one default lowering rule. _sin_lowering already falls
# back to hlo.sine for non-complex inputs, so a separate registration of
# partial(_nary_lower_hlo, hlo.sine) for the same primitive is redundant
# (NOTE(review): a second register_lowering call for the same primitive is
# expected to replace the first - confirm against mlir.register_lowering).
mlir.register_lowering(sin_p, _sin_lowering)
def _cos_complex(x):
  """Evaluate cos(x) for complex x from real-valued sin, cos and expm1.

  The computation follows the decomposition

    cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))

  with sinh and cosh expressed through expm1 (see also _sin_complex):

    2 * sinh(t) = expm1(t) - expm1(-t)
    2 * cosh(t) = expm1(t) + expm1(-t) + 2
  """
  x_re, x_im = real(x), imag(x)
  real_part_is_zero = eq(x_re, _const(x_re, 0))
  sin_re, cos_re = sin(x_re), cos(x_re)
  expm1_pos, expm1_neg = expm1(x_im), expm1(-x_im)
  sinh_im = (expm1_pos - expm1_neg) / 2
  cosh_im = (expm1_pos + expm1_neg + 2) / 2
  out_re = cos_re * cosh_im
  out_im = -sin_re * sinh_im
  # When real(x) is zero the imaginary part is set to an exact zero, which
  # keeps sin_re * sinh_im from producing nan (0 * inf) once sinh_im has
  # overflowed to inf.
  return select(real_part_is_zero,
                complex(out_re, _const(x_re, 0)),
                complex(out_re, out_im))
def _cos_lowering(ctx, x):
  """Lowering rule for cos_p.

  Complex inputs are lowered through _cos_complex (wrapped with
  mlir.lower_fun); every other input dtype is lowered directly to hlo.cosine.
  """
  operand_dtype = ctx.avals_in[0].dtype
  if not dtypes.issubdtype(operand_dtype, np.complexfloating):
    return _nary_lower_hlo(hlo.cosine, ctx, x)
  complex_cosine = mlir.lower_fun(_cos_complex, multiple_results=False)
  return complex_cosine(ctx, x)
# cos primitive, accepting the dtypes in _float | _complex.
cos_p = standard_unop(_float | _complex, 'cos')
# JVP rule: d/dx cos(x) = -sin(x).
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
# Register exactly one default lowering rule. _cos_lowering already falls
# back to hlo.cosine for non-complex inputs, so a separate registration of
# partial(_nary_lower_hlo, hlo.cosine) for the same primitive is redundant
# (NOTE(review): a second register_lowering call for the same primitive is
# expected to replace the first - confirm against mlir.register_lowering).
mlir.register_lowering(cos_p, _cos_lowering)
@_upcast_fp16_for_computation
def _tan_impl(x):

View File

@ -1413,3 +1413,63 @@ def numpy_vecdot(x, y, axis):
y = np.moveaxis(y, axis, -1)
x, y = np.broadcast_arrays(x, y)
return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
def complex_plane_sample(dtype, size_re=10, size_im=None):
  """Return a 2-D array of complex samples that covers the complex plane.

  The returned grid has shape (3 + 2 * size_im, 3 + 2 * size_re): rows vary
  along the imaginary axis and columns along the real axis. Each axis
  consists of -inf, `size` negative points, zero, `size` positive points and
  +inf. The finite non-zero points are logarithmically spaced, with the
  outermost ones pinned to finfo.min / finfo.max and the innermost ones to
  -finfo.tiny / finfo.tiny of the floating point type underlying `dtype`.

  Args:
    dtype: complex dtype of the returned array.
    size_re: number of finite positive (and likewise negative) points on
      the real axis.
    size_im: number of finite positive (and likewise negative) points on
      the imaginary axis. Defaults to size_re when None.

  Returns:
    A 2-D array of the given dtype.

  For example:

    >>> print(complex_plane_sample(np.complex64, 0, 3))
    [[-inf          -infj   0.          -infj  inf          -infj]
     [-inf-3.4028235e+38j   0.-3.4028235e+38j  inf-3.4028235e+38j]
     [-inf-2.0000052e+00j   0.-2.0000052e+00j  inf-2.0000052e+00j]
     [-inf-1.1754944e-38j   0.-1.1754944e-38j  inf-1.1754944e-38j]
     [-inf+0.0000000e+00j   0.+0.0000000e+00j  inf+0.0000000e+00j]
     [-inf+1.1754944e-38j   0.+1.1754944e-38j  inf+1.1754944e-38j]
     [-inf+2.0000052e+00j   0.+2.0000052e+00j  inf+2.0000052e+00j]
     [-inf+3.4028235e+38j   0.+3.4028235e+38j  inf+3.4028235e+38j]
     [-inf          +infj   0.          +infj  inf          +infj]]
  """
  if size_im is None:
    size_im = size_re
  float_info = np.finfo(dtype)
  float_dtype = float_info.dtype

  def make_axis_points(size):
    # Layout: [-inf, <size negative points>, 0, <size positive points>, +inf]
    log_neg_extreme = np.log10(abs(float_info.min))
    log_tiny = np.log10(float_info.tiny)
    log_max = np.log10(float_info.max)
    points = np.zeros(3 + 2 * size, dtype=float_dtype)

    with warnings.catch_warnings():
      # Silence RuntimeWarning: overflow encountered in cast
      warnings.simplefilter("ignore")
      points[1:size + 1] = -np.logspace(log_neg_extreme, log_tiny, size, dtype=float_dtype)
      points[-size - 1:-1] = np.logspace(log_tiny, log_max, size, dtype=float_dtype)

    # Pin the outermost and innermost finite points to the exact extreme
    # values rather than relying on the rounded logspace end points.
    if size > 1:
      points[1], points[-2] = float_info.min, float_info.max
    if size > 0:
      points[size], points[-size - 1] = -float_info.tiny, float_info.tiny
    points[0], points[-1] = -np.inf, np.inf
    return points

  re_points = make_axis_points(size_re)
  im_points = make_axis_points(size_im)

  # Fill the real and imaginary components separately; multiplying an axis
  # containing infinities by 1j would introduce nan components.
  grid = np.zeros((3 + 2 * size_im, 3 + 2 * size_re), dtype=dtype)
  grid.real[...] = re_points[np.newaxis, :]
  grid.imag[...] = im_points[:, np.newaxis]
  return grid

View File

@ -96,12 +96,14 @@ def main(_):
# CHECK-LABEL: TEST: cos complex64[]
# CHECK: hlo.cosine
# CHECK-SAME: tensor<complex<f32>>
# TODO: when the accuracy of lax.cos is fixed upstream, undo relevant parts of jax PR 19823
# CHECK-SAME: tensor<f32>
print_ir(np.complex64(0))(lax.cos)
# CHECK-LABEL: TEST: cos complex128[]
# CHECK: hlo.cosine
# CHECK-SAME: tensor<complex<f64>>
# TODO: when the accuracy of lax.cos is fixed upstream, undo relevant parts of jax PR 19823
# CHECK-SAME: tensor<f64>
print_ir(np.complex128(0))(lax.cos)

View File

@ -46,6 +46,7 @@ from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.util import NumpyComplexWarning
# Parse the configuration flags through absl at import time, before the test
# classes below are defined.
config.parse_flags_with_absl()
@ -3335,5 +3336,117 @@ class CustomElementTypesTest(jtu.JaxTestCase):
# TODO(frostig,mattjj): more polymorphic primitives tests
class FunctionAccuracyTest(jtu.JaxTestCase):
  """Checks jax.numpy elementwise functions against NumPy on complex inputs."""

  @parameterized.named_parameters(
    dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind)
    for name, dtype, kind in itertools.product(
        [ 'arccos', 'arccosh', 'arcsin', 'arcsinh',
          'arctan', 'arctanh', 'conjugate', 'cos',
          'cosh', 'exp', 'exp2', 'expm1', 'log',
          'log10', 'log1p', 'sin', 'sinh', 'sqrt',
          'square', 'tan', 'tanh', 'sinc', 'positive',
          'negative', 'absolute', 'sign'],
        jtu.dtypes.supported([np.complex64, np.complex128]),
        ['success', 'failure'],
    ))
  @jtu.skip_on_devices("tpu")
  def testOnComplexPlane(self, name, dtype, kind):
    """Compare jnp.<name> with its NumPy reference region by region.

    The function is evaluated on jtu.complex_plane_sample and the result is
    split into named regions (quadrants, half-axes, infinity borders and the
    origin). With kind == 'success' every region that is not listed as
    inaccurate must match the reference; with kind == 'failure' every region
    that is listed as inaccurate must fail the comparison.
    """
    quadrants = ['q1', 'q2', 'q3', 'q4']
    all_regions = [*quadrants, 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero']
    everywhere_but_zero = all_regions[:-1]
    is_cpu = jtu.test_device_matches(["cpu"])
    is_cuda = jtu.test_device_matches(["cuda"])

    # TODO(pearu): eliminate all items in the following lists:
    # TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests
    regions_with_inaccuracies = {
      'absolute': quadrants if dtype == np.complex128 and is_cuda else [],
      'exp': ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf'],
      'exp2': ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])],
      'log': quadrants,
      'log1p': [*quadrants, 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'],
      'log10': [*quadrants, 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'],
      'sinh': ['pos', 'neg', 'ninf', 'pinf'],
      'cosh': ['pos', 'neg', 'ninf', 'pinf'],
      'tan': [*quadrants, 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
      'square': (([*quadrants, 'ninfj', 'pinfj'] if is_cuda else [])
                 + ['ninf', 'pinf']
                 + (quadrants if is_cpu and dtype == np.complex128 else [])),
      'sinc': quadrants,
      'sign': [*quadrants, 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
      'arcsin': everywhere_but_zero,
      'arccos': everywhere_but_zero,
      'arctan': everywhere_but_zero,
      'arcsinh': everywhere_but_zero,
      'arccosh': everywhere_but_zero,
      'arctanh': everywhere_but_zero,
    }
    known_inaccurate = regions_with_inaccuracies.get(name, [])

    jnp_op = getattr(jnp, name)

    if name != 'square':
      np_op = getattr(np, name)
    else:
      # numpy square is incorrect on inputs with large absolute value
      tiny = np.finfo(dtype).tiny

      def square(x):
        re_part = (x.real - x.imag) * (x.real + x.imag)
        im_part = x.real * x.imag * 2
        if is_cuda:
          # apply FTZ
          if np.isfinite(re_part) and abs(re_part) < tiny:
            re_part *= 0
          if np.isfinite(im_part) and abs(im_part) < tiny:
            im_part *= 0
        return np.array(complex(re_part, im_part), dtype=dtype)

      np_op = np.vectorize(square)

    with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
      samples = jtu.complex_plane_sample(dtype=dtype, size_re=11)
      result = np.asarray(jnp_op(samples))
      expected = np_op(samples)

    # Number of finite points on each side of zero along the imaginary axis
    # (rows) and the real axis (columns) of the result grid.
    n_im, n_re = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2
    zero_row, zero_col = n_im + 1, n_re + 1
    neg_rows, pos_rows = slice(1, n_im + 1), slice(n_im + 2, -1)
    neg_cols, pos_cols = slice(1, n_re + 1), slice(n_re + 2, -1)
    everything = slice(None)
    region_slices = {
      'q1': (pos_rows, pos_cols),
      'q2': (pos_rows, neg_cols),
      'q3': (neg_rows, neg_cols),
      'q4': (neg_rows, pos_cols),
      'neg': (zero_row, neg_cols),
      'pos': (zero_row, pos_cols),
      'negj': (neg_rows, zero_col),
      'posj': (pos_rows, zero_col),
      'ninf': (everything, 0),
      'pinf': (everything, -1),
      'ninfj': (0, everything),
      'pinfj': (-1, everything),
      'zero': (slice(zero_row, zero_row + 1), slice(zero_col, zero_col + 1)),
    }

    for region in all_regions:
      region_slice = region_slices[region]
      got, want = result[region_slice], expected[region_slice]

      # Collect a readable report of the elements that differ.
      report_lines = []
      differing = np.where(got != want)
      if differing[0].size > 0:
        inputs = samples[region_slice]
        for idx in zip(*differing):
          got_str, want_str = str(got[idx]), str(want[idx])
          if got_str == want_str:
            # skip equal nan-s
            continue
          report_lines.append(f'jax.numpy.{name}{inputs[idx]} -> {got_str}, expected {want_str}')
      mismatches = "\n".join(report_lines)

      region_is_inaccurate = region in known_inaccurate
      if kind == 'success' and not region_is_inaccurate:
        with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
          self.assertAllClose(got, want, err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}")
      elif kind == 'failure' and region_is_inaccurate:
        with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
          with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
            self.assertAllClose(got, want)  # on success, update regions_with_inaccuracies
# When executed as a script, run the tests through absltest using the
# JaxTestLoader test loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())