diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst index 330c8c9bb..68c7d6e4d 100644 --- a/docs/CHANGELOG.rst +++ b/docs/CHANGELOG.rst @@ -12,6 +12,7 @@ jax 0.2.10 (Unreleased) * `GitHub commits `__. * New features: * :func:`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods. + * :func:`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods. * Bug fixes: diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 75e8cc86d..41d27cc35 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -129,6 +129,16 @@ jax.scipy.stats.beta logpdf pdf +jax.scipy.stats.betabinom +~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.betabinom + +.. autosummary:: + :toctree: _autosummary + + logpmf + pmf + jax.scipy.stats.cauchy ~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: jax.scipy.stats.cauchy diff --git a/jax/_src/scipy/stats/betabinom.py b/jax/_src/scipy/stats/betabinom.py new file mode 100644 index 000000000..64f407081 --- /dev/null +++ b/jax/_src/scipy/stats/betabinom.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + + +import scipy.stats as osp_stats + +from jax import lax +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact, _constant_like, where, inf, logical_or, nan +from jax._src.scipy.special import betaln + + +@_wraps(osp_stats.betabinom.logpmf, update_doc=False) +def logpmf(k, n, a, b, loc=0): + k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc) + y = lax.sub(lax.floor(k), loc) + one = _constant_like(y, 1) + zero = _constant_like(y, 0) + combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one)))) + beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b)) + log_probs = lax.add(combiln, beta_lns) + y_cond = logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, lax.sub(n, loc))) + log_probs = where(y_cond, -inf, log_probs) + n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)), lax.lt(b, zero)) + return where(n_a_b_cond, nan, log_probs) + +@_wraps(osp_stats.betabinom.pmf, update_doc=False) +def pmf(k, n, a, b, loc=0): + return lax.exp(logpmf(k, n, a, b, loc)) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index c048c4f9c..52246b416 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -29,3 +29,4 @@ from . import poisson from . import t from . import uniform from . import chi2 +from . import betabinom \ No newline at end of file diff --git a/jax/scipy/stats/betabinom.py b/jax/scipy/stats/betabinom.py new file mode 100644 index 000000000..afbf0a4b9 --- /dev/null +++ b/jax/scipy/stats/betabinom.py @@ -0,0 +1,20 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: F401 + +from jax._src.scipy.stats.betabinom import ( + logpmf, + pmf, +) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index afd8b9c99..ec2f368af 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -420,6 +420,25 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(5) + def testBetaBinomLogPmf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.betabinom.logpmf + lax_fun = lsp_stats.betabinom.logpmf + + def args_maker(): + k, n, a, b, loc = map(rng, shapes, dtypes) + k = np.floor(k) + n = np.ceil(n) + a = np.clip(a, a_min = 0.1, a_max = None) + b = np.clip(a, a_min = 0.1, a_max = None) + loc = np.floor(loc) + return [k, n, a, b, loc] + + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) + def testIssue972(self): self.assertAllClose( np.ones((4,), np.float32),