From ee5c134e66dfa2e1fa4795943e26bd24310927a9 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 24 Apr 2024 23:49:10 +0300 Subject: [PATCH] Workaround mpmath 1.3 issues in asinh evaluation at infinities --- jax/_src/test_util.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 517c48487..411347153 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1796,12 +1796,12 @@ class numpy_with_mpmath: def arcsin(self, x): ctx = x.context if isinstance(x, ctx.mpc): - # Workaround mpmath 1.3 bug in asin(+-inf+-infj) evaluation (see mpmath/mpmath#793). - # TODO(pearu): remove this function when mpmath 1.4 or newer - # will be the required test dependency. + # Workaround mpmath 1.3 bug in asin(+-inf+-infj) evaluation (see + # mpmath/mpmath#793). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. pi = ctx.pi inf = ctx.inf - nan = ctx.nan zero = ctx.zero if ctx.isinf(x.real): sign_real = -1 if x.real < 0 else 1 @@ -1811,9 +1811,10 @@ class numpy_with_mpmath: elif ctx.isinf(x.imag): return ctx.make_mpc((zero._mpf_, x.imag._mpf_)) - # TODO(pearu): adjust this code according to mpmath/mpmath#786 - # resolution when mpmath 1.4 or newer will be the required test - # dependency. + # On branch cut, mpmath.mp.asin returns different value compared + # to mpmath.fp.asin and numpy.arcsin (see + # mpmath/mpmath#786). The following if-block ensures + # compatibiliy with numpy.arcsin. if x.real > 1 and x.imag == 0: return ctx.asin(x).conjugate() @@ -1822,10 +1823,26 @@ class numpy_with_mpmath: def arcsinh(self, x): ctx = x.context - # TODO(pearu): adjust this code according to mpmath/mpmath#786 - # resolution when mpmath 1.4 or newer will be the required test - # dependency. if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in asinh(+-inf+-infj) evaluation + # (see mpmath/mpmath#749). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. + pi = ctx.pi + inf = ctx.inf + zero = ctx.zero + if ctx.isinf(x.imag): + sign_imag = -1 if x.imag < 0 else 1 + real = -inf if x.real < 0 else inf + imag = sign_imag * pi / (4 if ctx.isinf(x.real) else 2) + return ctx.make_mpc((real._mpf_, imag._mpf_)) + elif ctx.isinf(x.real): + return ctx.make_mpc((x.real._mpf_, zero._mpf_)) + + # On branch cut, mpmath.mp.asinh returns different value + # compared to mpmath.fp.asinh and numpy.arcsinh (see + # mpmath/mpmath#786). The following if-block ensures + # compatibiliy with numpy.arcsinh. if x.real == 0 and x.imag < -1: return (-ctx.asinh(x)).conjugate() return ctx.asinh(x)