Fix#10219

This commit is contained in:
YouJiacheng 2022-04-11 13:19:47 +08:00
parent 35b32eef96
commit 4695dd919c
2 changed files with 14 additions and 1 deletions

View File

@ -16,12 +16,18 @@ import scipy.stats as osp_stats
from jax.scipy.special import expit, logit
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy import lax_numpy as jnp
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
def logpdf(x):
return lax.neg(x) - 2. * lax.log1p(lax.exp(lax.neg(x)))
x, = _promote_args_inexact("logistic.logpdf", x)
two = _lax_const(x, 2)
half_x = lax.div(x, two)
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))
@_wraps(osp_stats.logistic.pdf, update_doc=False)
def pdf(x):

View File

@ -322,6 +322,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
def testLogisticLogpdfOverflow(self):
# Regression test for https://github.com/google/jax/issues/10219
self.assertAllClose(
np.array([-100, -100], np.float32),
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
check_dtypes=False)
@genNamedParametersNArgs(1)
def testLogisticPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())