Merge pull request #20560 from pearu:pearu/log1p-fixes

PiperOrigin-RevId: 621562841
This commit is contained in:
jax authors 2024-04-03 10:17:11 -07:00
commit 57ee6b7550
2 changed files with 21 additions and 5 deletions

View File

@ -1746,6 +1746,24 @@ class numpy_with_mpmath:
def expm1(self, x):
return x.context.expm1(x)
def log1p(self, x):
ctx = x.context
if isinstance(x, ctx.mpc):
# Workaround mpmath 1.3 bug in log(+-inf+-infj) evaluation (see mpmath/mpmath#774).
# TODO(pearu): remove this function when mpmath 1.4 or newer
# will be the required test dependency.
if ctx.isinf(x.real) and ctx.isinf(x.imag):
pi = ctx.pi
if x.real > 0 and x.imag > 0:
return ctx.make_mpc((x.real._mpf_, (pi / 4)._mpf_))
if x.real > 0 and x.imag < 0:
return ctx.make_mpc((x.real._mpf_, (-pi / 4)._mpf_))
if x.real < 0 and x.imag < 0:
return ctx.make_mpc(((-x.real)._mpf_, (-3 * pi / 4)._mpf_))
if x.real < 0 and x.imag > 0:
return ctx.make_mpc(((-x.real)._mpf_, (3 * pi / 4)._mpf_))
return ctx.log1p(x)
def log2(self, x):
return x.context.ln(x) / x.context.ln2

View File

@ -3666,11 +3666,9 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
elif name == 'arctanh':
if xla_extension_version < 251:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
else:
regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
# TODO(pearu): after landing openxla/xla#10503, switch to
# regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1'}:
regions_with_inaccuracies.clear()
else: