mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #20560 from pearu:pearu/log1p-fixes
PiperOrigin-RevId: 621562841
This commit is contained in:
commit
57ee6b7550
@ -1746,6 +1746,24 @@ class numpy_with_mpmath:
|
||||
def expm1(self, x):
|
||||
return x.context.expm1(x)
|
||||
|
||||
def log1p(self, x):
|
||||
ctx = x.context
|
||||
if isinstance(x, ctx.mpc):
|
||||
# Workaround mpmath 1.3 bug in log(+-inf+-infj) evaluation (see mpmath/mpmath#774).
|
||||
# TODO(pearu): remove this function when mpmath 1.4 or newer
|
||||
# will be the required test dependency.
|
||||
if ctx.isinf(x.real) and ctx.isinf(x.imag):
|
||||
pi = ctx.pi
|
||||
if x.real > 0 and x.imag > 0:
|
||||
return ctx.make_mpc((x.real._mpf_, (pi / 4)._mpf_))
|
||||
if x.real > 0 and x.imag < 0:
|
||||
return ctx.make_mpc((x.real._mpf_, (-pi / 4)._mpf_))
|
||||
if x.real < 0 and x.imag < 0:
|
||||
return ctx.make_mpc(((-x.real)._mpf_, (-3 * pi / 4)._mpf_))
|
||||
if x.real < 0 and x.imag > 0:
|
||||
return ctx.make_mpc(((-x.real)._mpf_, (3 * pi / 4)._mpf_))
|
||||
return ctx.log1p(x)
|
||||
|
||||
def log2(self, x):
|
||||
return x.context.ln(x) / x.context.ln2
|
||||
|
||||
|
@ -3666,11 +3666,9 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
|
||||
|
||||
elif name == 'arctanh':
|
||||
if xla_extension_version < 251:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
|
||||
else:
|
||||
regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
|
||||
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
|
||||
# TODO(pearu): after landing openxla/xla#10503, switch to
|
||||
# regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
|
||||
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1'}:
|
||||
regions_with_inaccuracies.clear()
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user