mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
stats.binom.pmf: return zero for k > n
This commit is contained in:
parent
697f17adf1
commit
77258cd6bd
@ -33,7 +33,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
|
||||
)
|
||||
log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
|
||||
log_probs = lax.add(comb_term, log_linear_term)
|
||||
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
|
||||
return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf)
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
|
||||
|
@ -1159,6 +1159,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
def testBinomPmfOutOfRange(self):
|
||||
# Regression test for https://github.com/google/jax/issues/19150
|
||||
self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0)
|
||||
|
||||
def testIssue972(self):
|
||||
self.assertAllClose(
|
||||
np.ones((4,), np.float32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user