mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan.
This commit is contained in:
parent
abcc7fdf4c
commit
f608a8c502
@ -306,13 +306,13 @@ def igamma_impl(a, x, *, dtype):
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
a_is_zero = eq(a, _const(a, 0))
|
||||
x_is_zero = eq(x, _const(x, 0))
|
||||
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
|
||||
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero), is_nan])
|
||||
|
||||
use_igammac = bitwise_and(ge(x, _const(x, 1)), gt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
ax = exp(ax)
|
||||
enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, is_nan, x_is_infinity]))
|
||||
enabled = bitwise_not(_reduce(bitwise_or, [x_is_zero, domain_error, underflow, x_is_infinity]))
|
||||
|
||||
output = select(
|
||||
use_igammac,
|
||||
@ -437,11 +437,11 @@ def igammac_impl(a, x, *, dtype):
|
||||
a_is_zero = eq(a, _const(a, 0))
|
||||
x_is_zero = eq(x, _const(x, 0))
|
||||
x_is_infinity = eq(x, _const(x, float('inf')))
|
||||
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero)])
|
||||
domain_error = _reduce(bitwise_or, [lt(x, _const(x, 0)), lt(a, _const(a, 0)), bitwise_and(a_is_zero, x_is_zero), is_nan])
|
||||
use_igamma = bitwise_or(lt(x, _const(x, 1)), lt(x, a))
|
||||
ax = a * log(x) - x - lgamma(a)
|
||||
underflow = lt(ax, -log(dtypes.finfo(dtype).max))
|
||||
enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, is_nan, x_is_infinity, a_is_zero]))
|
||||
enabled = bitwise_not(_reduce(bitwise_or, [domain_error, underflow, x_is_infinity, a_is_zero]))
|
||||
ax = exp(ax)
|
||||
|
||||
igamma_call = _igamma_series(ax, x, a, bitwise_and(enabled, use_igamma),
|
||||
|
@ -170,7 +170,7 @@ def _pretty_special_fun_name(case):
|
||||
return dict(**case, testcase_name=name)
|
||||
|
||||
|
||||
class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
|
||||
class LaxScipySpecialFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
def _GetArgsMaker(self, rng, shapes, dtypes):
|
||||
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
@ -291,20 +291,30 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
|
||||
dtype = jax.numpy.zeros(0).dtype # default float dtype.
|
||||
nan = float('nan')
|
||||
inf = float('inf')
|
||||
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan]).astype(dtype),
|
||||
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf]).astype(dtype)]
|
||||
if jtu.parse_version(scipy.__version__) >= (1, 16):
|
||||
samples_slice = slice(None)
|
||||
else:
|
||||
# disable samples that contradict with scipy/scipy#22441
|
||||
samples_slice = slice(None, -1)
|
||||
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype),
|
||||
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)]
|
||||
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
|
||||
self._CheckAgainstNumpy(osp_special.gammainc, lsp_special.gammainc, args_maker, rtol=rtol)
|
||||
self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol)
|
||||
self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol)
|
||||
|
||||
def testGammaIncCBoundaryValues(self):
|
||||
dtype = jax.numpy.zeros(0).dtype # default float dtype.
|
||||
nan = float('nan')
|
||||
inf = float('inf')
|
||||
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan, 1]).astype(dtype),
|
||||
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf, -1]).astype(dtype)]
|
||||
if jtu.parse_version(scipy.__version__) >= (1, 16):
|
||||
samples_slice = slice(None)
|
||||
else:
|
||||
# disable samples that contradict with scipy/scipy#22441
|
||||
samples_slice = slice(None, -1)
|
||||
args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype),
|
||||
np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)]
|
||||
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
|
||||
self._CheckAgainstNumpy(osp_special.gammaincc, lsp_special.gammaincc, args_maker, rtol=rtol)
|
||||
self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol)
|
||||
self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user