stats.binom.pmf: return zero for k > n

This commit is contained in:
Jake VanderPlas 2024-01-02 10:53:44 -08:00
parent 697f17adf1
commit 77258cd6bd
2 changed files with 5 additions and 1 deletions

View File

@ -33,7 +33,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
)
log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
log_probs = lax.add(comb_term, log_linear_term)
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf)
@_wraps(osp_stats.nbinom.pmf, update_doc=False)

View File

@ -1159,6 +1159,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
def testBinomPmfOutOfRange(self):
# Regression test for https://github.com/google/jax/issues/19150
self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0)
def testIssue972(self):
self.assertAllClose(
np.ones((4,), np.float32),