1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan.

This commit is contained in:
Pearu Peterson 2025-03-12 17:22:46 +02:00
parent abcc7fdf4c
commit f608a8c502
2 changed files with 21 additions and 11 deletions

@ -306,13 +306,13 @@ def igamma_impl(a, x, *, dtype):
x_is_infinity = eq(x, _const(x, float('inf')))
a_is_zero = eq(a, _const(a, 0))
x_is_zero = eq(x, _const(x, 0))
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero), is_nan])
use_igammac = bitwise_and(ge(x, _const(x, 1)), gt(x, a))
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
ax = exp(ax)
enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan, x_is_infinity]))
enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, x_is_infinity]))
output = select(
use_igammac,
@ -437,11 +437,11 @@ def igammac_impl(a, x, *, dtype):
a_is_zero = eq(a, _const(a, 0))
x_is_zero = eq(x, _const(x, 0))
x_is_infinity = eq(x, _const(x, float('inf')))
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero), is_nan])
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
ax = a * log(x) - x - lgamma(a)
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, is_nan, x_is_infinity, a_is_zero]))
enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, x_is_infinity, a_is_zero]))
ax = exp(ax)
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),

@ -170,7 +170,7 @@ def _pretty_special_fun_name(case):
return dict(**case, testcase_name=name)
class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
class LaxScipySpecialFunctionsTest(jtu.JaxTestCase):
def _GetArgsMaker(self, rng, shapes, dtypes):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
@ -291,20 +291,30 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
dtype = jax.numpy.zeros(0).dtype # default float dtype.
nan = float('nan')
inf = float('inf')
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan]).astype(dtype),
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf]).astype(dtype)]
if jtu.parse_version(scipy.__version__) >= (1, 16):
samples_slice = slice(None)
else:
# disable samples that contradict with scipy/scipy#22441
samples_slice = slice(None, -1)
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype),
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.gammainc, lsp_special.gammainc, args_maker, rtol=rtol)
self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol)
def testGammaIncCBoundaryValues(self):
dtype = jax.numpy.zeros(0).dtype # default float dtype.
nan = float('nan')
inf = float('inf')
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan, 1]).astype(dtype),
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf, -1]).astype(dtype)]
if jtu.parse_version(scipy.__version__) >= (1, 16):
samples_slice = slice(None)
else:
# disable samples that contradict with scipy/scipy#22441
samples_slice = slice(None, -1)
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype),
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.gammaincc, lsp_special.gammaincc, args_maker, rtol=rtol)
self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol)