From 77258cd6bdb08f3bf3d0c991b2beea8222aeb7db Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 2 Jan 2024 10:53:44 -0800 Subject: [PATCH] stats.binom.pmf: return zero for k > n --- jax/_src/scipy/stats/binom.py | 2 +- tests/scipy_stats_test.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/_src/scipy/stats/binom.py b/jax/_src/scipy/stats/binom.py index 852d39edc..869eab91c 100644 --- a/jax/_src/scipy/stats/binom.py +++ b/jax/_src/scipy/stats/binom.py @@ -33,7 +33,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra ) log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p))) log_probs = lax.add(comb_term, log_linear_term) - return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs) + return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf) @_wraps(osp_stats.nbinom.pmf, update_doc=False) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 2db5b354a..09b12b87b 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1159,6 +1159,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) + def testBinomPmfOutOfRange(self): + # Regression test for https://github.com/google/jax/issues/19150 + self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0) + def testIssue972(self): self.assertAllClose( np.ones((4,), np.float32),