From 30d1a8a80f6031032e1766ccf9804d94ab26e29d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 26 Jun 2023 23:17:54 -0700 Subject: [PATCH] Add jax.scipy.stats.binom --- jax/_src/scipy/stats/binom.py | 42 +++++++++++++++++++++++++++++++++++ jax/scipy/stats/__init__.py | 1 + jax/scipy/stats/binom.py | 18 +++++++++++++++ tests/scipy_stats_test.py | 21 ++++++++++++++++++ 4 files changed, 82 insertions(+) create mode 100644 jax/_src/scipy/stats/binom.py create mode 100644 jax/scipy/stats/binom.py diff --git a/jax/_src/scipy/stats/binom.py b/jax/_src/scipy/stats/binom.py new file mode 100644 index 000000000..852d39edc --- /dev/null +++ b/jax/_src/scipy/stats/binom.py @@ -0,0 +1,42 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + + +import scipy.stats as osp_stats + +from jax import lax +import jax.numpy as jnp +from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.scipy.special import gammaln, xlogy, xlog1py +from jax._src.typing import Array, ArrayLike + + +@_wraps(osp_stats.nbinom.logpmf, update_doc=False) +def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: + """JAX implementation of scipy.stats.binom.logpmf.""" + k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc) + y = lax.sub(k, loc) + comb_term = lax.sub( + gammaln(n + 1), + lax.add(gammaln(y + 1), gammaln(n - y + 1)) + ) + log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p))) + log_probs = lax.add(comb_term, log_linear_term) + return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs) + + +@_wraps(osp_stats.nbinom.pmf, update_doc=False) +def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: + """JAX implementation of scipy.stats.binom.pmf.""" + return lax.exp(logpmf(k, n, p, loc)) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 31a086516..5d57b1b61 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -17,6 +17,7 @@ from jax.scipy.stats import bernoulli as bernoulli from jax.scipy.stats import beta as beta +from jax.scipy.stats import binom as binom from jax.scipy.stats import cauchy as cauchy from jax.scipy.stats import dirichlet as dirichlet from jax.scipy.stats import expon as expon diff --git a/jax/scipy/stats/binom.py b/jax/scipy/stats/binom.py new file mode 100644 index 000000000..b011ed4fd --- /dev/null +++ b/jax/scipy/stats/binom.py @@ -0,0 +1,18 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.scipy.stats.binom import ( + logpmf as logpmf, + pmf as pmf, +) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index bf4176a2d..b73728f17 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -971,6 +971,27 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) + @genNamedParametersNArgs(4) + def testBinomLogPmf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.binom.logpmf + lax_fun = lsp_stats.binom.logpmf + + def args_maker(): + k, n, logit, loc = map(rng, shapes, dtypes) + k = np.floor(np.abs(k)) + n = np.ceil(np.abs(n)) + p = expit(logit) + loc = np.floor(loc) + return [k, n, p, loc] + + tol = {np.float32: 1e-6, np.float64: 1e-8} + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) + def testIssue972(self): self.assertAllClose( np.ones((4,), np.float32),