mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #19166 from jakevdp:fix-binom
PiperOrigin-RevId: 595168665
This commit is contained in:
commit
15f4a8d2ec
@ -33,7 +33,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra
|
||||
)
|
||||
log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
|
||||
log_probs = lax.add(comb_term, log_linear_term)
|
||||
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
|
||||
return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf)
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
|
||||
|
@ -1159,6 +1159,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
def testBinomPmfOutOfRange(self):
|
||||
# Regression test for https://github.com/google/jax/issues/19150
|
||||
self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0)
|
||||
|
||||
def testIssue972(self):
|
||||
self.assertAllClose(
|
||||
np.ones((4,), np.float32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user