mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix complex sin and cos on inputs with small absolute value or large pure imaginary part
This commit is contained in:
parent
d434ab55d7
commit
6d8b3e4cff
@ -1837,13 +1837,51 @@ def logistic_impl(x):
|
||||
mlir.register_lowering(logistic_p,
|
||||
mlir.lower_fun(logistic_impl, multiple_results=False))
|
||||
|
||||
def _sin_complex(x):
|
||||
# use expm1 instead of exp to avoid cancellation when abs(x) is small
|
||||
# relies on the quality of real-valued expm1, sin, cos
|
||||
# sin(x) = complex(sin(real(x)) * cosh(imag(x)), cos(real(x)) * sinh(imag(x)))
|
||||
# 2 * sinh(x) = exp(x) - 1 - (exp(-x) - 1) = expm1(x) - expm1(-x)
|
||||
# 2 * cosh(x) = exp(x) - 1 + (exp(-x) - 1) + 2 = expm1(x) + expm1(-x) + 2
|
||||
a, b = real(x), imag(x)
|
||||
a_is_zero = eq(a, _const(a, 0))
|
||||
sn, cs = sin(a), cos(a)
|
||||
e1m, e2m = expm1(b), expm1(-b)
|
||||
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
|
||||
re, im = sn * csh, cs * snh
|
||||
# avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf
|
||||
return select(a_is_zero, complex(_const(a, 0), im), complex(re, im))
|
||||
|
||||
def _sin_lowering(ctx, x):
|
||||
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
|
||||
sine = mlir.lower_fun(_sin_complex, multiple_results=False)
|
||||
return sine(ctx, x)
|
||||
return _nary_lower_hlo(hlo.sine, ctx, x)
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.sine))
|
||||
mlir.register_lowering(sin_p, _sin_lowering)
|
||||
|
||||
def _cos_complex(x):
|
||||
# cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
|
||||
# see also _sin_complex
|
||||
a, b = real(x), imag(x)
|
||||
a_is_zero = eq(a, _const(a, 0))
|
||||
sn, cs = sin(a), cos(a)
|
||||
e1m, e2m = expm1(b), expm1(-b)
|
||||
snh, csh = (e1m - e2m) / 2, (e1m + e2m + 2) / 2
|
||||
re, im = cs * csh, -sn * snh
|
||||
return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im))
|
||||
|
||||
def _cos_lowering(ctx, x):
|
||||
if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating):
|
||||
cosine = mlir.lower_fun(_cos_complex, multiple_results=False)
|
||||
return cosine(ctx, x)
|
||||
return _nary_lower_hlo(hlo.cosine, ctx, x)
|
||||
|
||||
cos_p = standard_unop(_float | _complex, 'cos')
|
||||
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
||||
mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.cosine))
|
||||
mlir.register_lowering(cos_p, _cos_lowering)
|
||||
|
||||
@_upcast_fp16_for_computation
|
||||
def _tan_impl(x):
|
||||
|
@ -1413,3 +1413,63 @@ def numpy_vecdot(x, y, axis):
|
||||
y = np.moveaxis(y, axis, -1)
|
||||
x, y = np.broadcast_arrays(x, y)
|
||||
return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
|
||||
|
||||
|
||||
def complex_plane_sample(dtype, size_re=10, size_im=None):
|
||||
"""Return a 2-D array of complex numbers that covers the complex plane
|
||||
with a grid of samples.
|
||||
|
||||
The size of the grid is (3 + 2 * size_im) x (3 + 2 * size_re)
|
||||
that includes infinity points, extreme finite points, and the
|
||||
specified number of points from real and imaginary axis.
|
||||
|
||||
For example:
|
||||
|
||||
>>> print(complex_plane_sample(np.complex64, 0, 3))
|
||||
[[-inf -infj 0. -infj inf -infj]
|
||||
[-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
|
||||
[-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j]
|
||||
[-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
|
||||
[-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
|
||||
[-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
|
||||
[-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j]
|
||||
[-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
|
||||
[-inf +infj 0. +infj inf +infj]]
|
||||
|
||||
"""
|
||||
if size_im is None:
|
||||
size_im = size_re
|
||||
finfo = np.finfo(dtype)
|
||||
|
||||
def make_axis_points(size):
|
||||
logmin = np.log10(abs(finfo.min))
|
||||
logtiny = np.log10(finfo.tiny)
|
||||
logmax = np.log10(finfo.max)
|
||||
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
# Silence RuntimeWarning: overflow encountered in cast
|
||||
warnings.simplefilter("ignore")
|
||||
axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
|
||||
axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype)
|
||||
|
||||
if size > 1:
|
||||
axis_points[1] = finfo.min
|
||||
axis_points[-2] = finfo.max
|
||||
if size > 0:
|
||||
axis_points[size] = -finfo.tiny
|
||||
axis_points[-size - 1] = finfo.tiny
|
||||
axis_points[0] = -np.inf
|
||||
axis_points[-1] = np.inf
|
||||
return axis_points
|
||||
|
||||
real_axis_points = make_axis_points(size_re)
|
||||
imag_axis_points = make_axis_points(size_im)
|
||||
|
||||
real_part = real_axis_points.reshape((-1, 3 + 2 * size_re)).repeat(3 + 2 * size_im, 0).astype(dtype)
|
||||
|
||||
imag_part = imag_axis_points.repeat(2).view(dtype)
|
||||
imag_part.real[:] = 0
|
||||
imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)
|
||||
|
||||
return real_part + imag_part
|
||||
|
@ -96,12 +96,14 @@ def main(_):
|
||||
|
||||
# CHECK-LABEL: TEST: cos complex64[]
|
||||
# CHECK: hlo.cosine
|
||||
# CHECK-SAME: tensor<complex<f32>>
|
||||
# TODO: when the accuracy of lax.cos is fixed upstream, undo relevant parts of jax PR 19823
|
||||
# CHECK-SAME: tensor<f32>
|
||||
print_ir(np.complex64(0))(lax.cos)
|
||||
|
||||
# CHECK-LABEL: TEST: cos complex128[]
|
||||
# CHECK: hlo.cosine
|
||||
# CHECK-SAME: tensor<complex<f64>>
|
||||
# TODO: when the accuracy of lax.cos is fixed upstream, undo relevant parts of jax PR 19823
|
||||
# CHECK-SAME: tensor<f64>
|
||||
print_ir(np.complex128(0))(lax.cos)
|
||||
|
||||
|
||||
|
@ -46,6 +46,7 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import NumpyComplexWarning
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -3335,5 +3336,117 @@ class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
|
||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||
|
||||
|
||||
class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_{name}_{dtype.__name__}_{kind}", name=name, dtype=dtype, kind=kind)
|
||||
for name, dtype, kind in itertools.product(
|
||||
[ 'arccos', 'arccosh', 'arcsin', 'arcsinh',
|
||||
'arctan', 'arctanh', 'conjugate', 'cos',
|
||||
'cosh', 'exp', 'exp2', 'expm1', 'log',
|
||||
'log10', 'log1p', 'sin', 'sinh', 'sqrt',
|
||||
'square', 'tan', 'tanh', 'sinc', 'positive',
|
||||
'negative', 'absolute', 'sign'],
|
||||
jtu.dtypes.supported([np.complex64, np.complex128]),
|
||||
['success', 'failure'],
|
||||
))
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testOnComplexPlane(self, name, dtype, kind):
|
||||
all_regions = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero']
|
||||
is_cpu = jtu.test_device_matches(["cpu"])
|
||||
is_cuda = jtu.test_device_matches(["cuda"])
|
||||
|
||||
# TODO(pearu): eliminate all items in the following lists:
|
||||
# TODO(pearu): when all items are eliminated, eliminate the kind == 'failure' tests
|
||||
regions_with_inaccuracies = dict(
|
||||
absolute = ['q1', 'q2', 'q3', 'q4'] if dtype == np.complex128 and is_cuda else [],
|
||||
exp = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf'],
|
||||
exp2 = ['pos', 'pinfj', 'pinf', 'ninfj', 'ninf', *(['q1', 'q4'] if is_cpu else [])],
|
||||
log = ['q1', 'q2', 'q3', 'q4'],
|
||||
log1p = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'ninfj', 'pinfj'],
|
||||
log10 = ['q1', 'q2', 'q3', 'q4', 'zero', 'ninf', 'ninfj', 'pinf', 'pinfj'],
|
||||
sinh = ['pos', 'neg', 'ninf', 'pinf'],
|
||||
cosh = ['pos', 'neg', 'ninf', 'pinf'],
|
||||
tan = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
|
||||
square = ((['q1', 'q2', 'q3', 'q4', 'ninfj', 'pinfj'] if is_cuda else [])
|
||||
+ ['ninf', 'pinf']
|
||||
+ (['q1', 'q2', 'q3', 'q4'] if is_cpu and dtype == np.complex128 else [])),
|
||||
sinc = ['q1', 'q2', 'q3', 'q4'],
|
||||
sign = ['q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninf', 'ninfj', 'pinf', 'pinfj'],
|
||||
arcsin = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arccos = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arctan = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arcsinh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arccosh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
arctanh = ['q1', 'q2', 'q3', 'q4', 'pos', 'neg', 'posj', 'negj', 'ninf', 'pinf', 'ninfj', 'pinfj'],
|
||||
)
|
||||
|
||||
jnp_op = getattr(jnp, name)
|
||||
|
||||
if name == 'square':
|
||||
# numpy square is incorrect on inputs with large absolute value
|
||||
tiny = np.finfo(dtype).tiny
|
||||
|
||||
def square(x):
|
||||
re = (x.real - x.imag) * (x.real + x.imag)
|
||||
im = x.real * x.imag * 2
|
||||
if is_cuda:
|
||||
# apply FTZ
|
||||
if np.isfinite(re) and abs(re) < tiny:
|
||||
re *= 0
|
||||
if np.isfinite(im) and abs(im) < tiny:
|
||||
im *= 0
|
||||
return np.array(complex(re, im), dtype=dtype)
|
||||
|
||||
np_op = np.vectorize(square)
|
||||
else:
|
||||
np_op = getattr(np, name)
|
||||
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="(overflow|invalid value|divide by zero) encountered in.*"):
|
||||
args = (jtu.complex_plane_sample(dtype=dtype, size_re=11),)
|
||||
result = np.asarray(jnp_op(*args))
|
||||
expected = np_op(*args)
|
||||
|
||||
s0, s1 = (result.shape[0] - 3) // 2, (result.shape[1] - 3) // 2
|
||||
s_dict = dict(
|
||||
q1=(slice(s0 + 2, -1), slice(s1 + 2, -1)),
|
||||
q2=(slice(s0 + 2, -1), slice(1, s1 + 1)),
|
||||
q3=(slice(1, s0 + 1), slice(1, s1 + 1)),
|
||||
q4=(slice(1, s0 + 1), slice(s1 + 2, -1)),
|
||||
neg=(s0 + 1, slice(1, s1 + 1)),
|
||||
pos=(s0 + 1, slice(s1 + 2, -1)),
|
||||
negj=(slice(1, s0 + 1), s1 + 1),
|
||||
posj=(slice(s0 + 2, -1), s1 + 1),
|
||||
ninf=(slice(None), 0),
|
||||
pinf=(slice(None), -1),
|
||||
ninfj=(0, slice(None)),
|
||||
pinfj=(-1, slice(None)),
|
||||
zero=(slice(s0 + 1, s0 + 2), slice(s1 + 1, s1 + 2)),
|
||||
)
|
||||
|
||||
for region in all_regions:
|
||||
s = s_dict[region]
|
||||
inds = np.where(result[s] != expected[s])
|
||||
if inds[0].size > 0:
|
||||
mismatches = []
|
||||
for ind in zip(*inds):
|
||||
x, r, e = args[0][s][ind], str(result[s][ind]), str(expected[s][ind])
|
||||
if r == e:
|
||||
# skip equal nan-s
|
||||
continue
|
||||
mismatches.append(f'jax.numpy.{name}{x} -> {r}, expected {e}')
|
||||
mismatches = "\n".join(mismatches)
|
||||
else:
|
||||
mismatches = ''
|
||||
if kind == 'success' and region not in regions_with_inaccuracies.get(name, []):
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||
self.assertAllClose(result[s], expected[s], err_msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{mismatches}")
|
||||
if kind == 'failure' and region in regions_with_inaccuracies.get(name, []):
|
||||
with self.assertRaises(AssertionError, msg=f"{name} in {region}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
|
||||
self.assertAllClose(result[s], expected[s]) # on success, update regions_with_inaccuracies
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user