Fix jax.scipy.stats.poisson.logpmf to emulate scipy.stats.poisson.logpmf for non-integer values of k

This commit is contained in:
rajasekharporeddy 2024-04-24 00:29:52 +05:30
parent 3146c2a3f6
commit 95ed0538fd
2 changed files with 2 additions and 3 deletions

View File

@ -29,7 +29,8 @@ def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
zero = _lax_const(k, 0)
x = lax.sub(k, loc)
log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs)
return jnp.where(jnp.logical_or(lax.lt(x, zero),
lax.ne(jnp.round(k), k)), -jnp.inf, log_probs)
@implements(osp_stats.poisson.pmf, update_doc=False)
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:

View File

@ -129,7 +129,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
k, mu, loc = map(rng, shapes, dtypes)
k = np.floor(k)
# clipping to ensure that rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
loc = np.floor(loc)
@ -148,7 +147,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
k, mu, loc = map(rng, shapes, dtypes)
k = np.floor(k)
# clipping to ensure that rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
loc = np.floor(loc)