Added scipy.stats.bernoulli cdf and ppf.

This commit is contained in:
harryjulian 2022-12-16 10:12:40 +00:00
parent b8ae8e3fa1
commit c0d4ae0cc3
4 changed files with 64 additions and 0 deletions

View File

@ -159,6 +159,8 @@ jax.scipy.stats.bernoulli
logpmf
pmf
cdf
ppf
jax.scipy.stats.beta
~~~~~~~~~~~~~~~~~~~~

View File

@ -36,3 +36,26 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
@_wraps(osp_stats.bernoulli.pmf, update_doc=False)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
return jnp.exp(logpmf(k, p, loc))
@_wraps(osp_stats.bernoulli.cdf, update_doc=False)
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
k, p = jnp._promote_args_inexact('bernoulli.cdf', k, p)
zero, one = _lax_const(k, 0), _lax_const(k, 1)
conds = [
jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
lax.lt(k, zero),
jnp.logical_and(lax.ge(k, zero), lax.lt(k, one)),
lax.ge(k, one)
]
vals = [jnp.nan, zero, one - p, one]
return jnp.select(conds, vals)
@_wraps(osp_stats.bernoulli.ppf, update_doc=False)
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
q, p = jnp._promote_args_inexact('bernoulli.ppf', q, p)
zero, one = _lax_const(q, 0), _lax_const(q, 1)
return jnp.where(
jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),
jnp.nan,
jnp.where(lax.le(q, one - p), zero, one)
)

View File

@ -18,4 +18,6 @@
from jax._src.scipy.stats.bernoulli import (
logpmf as logpmf,
pmf as pmf,
cdf as cdf,
ppf as ppf
)

View File

@ -155,6 +155,43 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testBernoulliCdf(self, shapes, dtypes):
rng_int = jtu.rand_int(self.rng(), -100, 100)
rng_uniform = jtu.rand_uniform(self.rng())
scipy_fun = osp_stats.bernoulli.cdf
lax_fun = lsp_stats.bernoulli.cdf
def args_maker():
x = rng_int(shapes[0], dtypes[0])
p = rng_uniform(shapes[1], dtypes[1])
return [x, p]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testBernoulliPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.bernoulli.ppf
lax_fun = lsp_stats.bernoulli.ppf
if scipy_version < (1, 9, 2):
self.skipTest("Scipy 1.9.2 needed for fix https://github.com/scipy/scipy/pull/17166.")
def args_maker():
q, p = map(rng, shapes, dtypes)
q = expit(q)
p = expit(p)
return [q, p]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testGeomLogPmf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())