Add jax.scipy.stats.binom

This commit is contained in:
Jake VanderPlas 2023-06-26 23:17:54 -07:00
parent 6bc74d2a98
commit 30d1a8a80f
4 changed files with 82 additions and 0 deletions

View File

@ -0,0 +1,42 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.nbinom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.binom.logpmf."""
k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
y = lax.sub(k, loc)
comb_term = lax.sub(
gammaln(n + 1),
lax.add(gammaln(y + 1), gammaln(n - y + 1))
)
log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
log_probs = lax.add(comb_term, log_linear_term)
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
"""JAX implementation of scipy.stats.binom.pmf."""
return lax.exp(logpmf(k, n, p, loc))

View File

@ -17,6 +17,7 @@
from jax.scipy.stats import bernoulli as bernoulli
from jax.scipy.stats import beta as beta
from jax.scipy.stats import binom as binom
from jax.scipy.stats import cauchy as cauchy
from jax.scipy.stats import dirichlet as dirichlet
from jax.scipy.stats import expon as expon

18
jax/scipy/stats/binom.py Normal file
View File

@ -0,0 +1,18 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.scipy.stats.binom import (
logpmf as logpmf,
pmf as pmf,
)

View File

@ -971,6 +971,27 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
@genNamedParametersNArgs(4)
def testBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.binom.logpmf
lax_fun = lsp_stats.binom.logpmf
def args_maker():
k, n, logit, loc = map(rng, shapes, dtypes)
k = np.floor(np.abs(k))
n = np.ceil(np.abs(n))
p = expit(logit)
loc = np.floor(loc)
return [k, n, p, loc]
tol = {np.float32: 1e-6, np.float64: 1e-8}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
def testIssue972(self):
self.assertAllClose(
np.ones((4,), np.float32),