mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Fix#10219
This commit is contained in:
parent
35b32eef96
commit
4695dd919c
@ -16,12 +16,18 @@ import scipy.stats as osp_stats
|
||||
from jax.scipy.special import expit, logit
|
||||
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
|
||||
def logpdf(x):
|
||||
return lax.neg(x) - 2. * lax.log1p(lax.exp(lax.neg(x)))
|
||||
x, = _promote_args_inexact("logistic.logpdf", x)
|
||||
two = _lax_const(x, 2)
|
||||
half_x = lax.div(x, two)
|
||||
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))
|
||||
|
||||
@_wraps(osp_stats.logistic.pdf, update_doc=False)
|
||||
def pdf(x):
|
||||
|
@ -322,6 +322,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testLogisticLogpdfOverflow(self):
|
||||
# Regression test for https://github.com/google/jax/issues/10219
|
||||
self.assertAllClose(
|
||||
np.array([-100, -100], np.float32),
|
||||
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
|
||||
check_dtypes=False)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
def testLogisticPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user