mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Added scipy.stats.bernoulli cdf and ppf.
This commit is contained in:
parent
b8ae8e3fa1
commit
c0d4ae0cc3
@ -159,6 +159,8 @@ jax.scipy.stats.bernoulli
|
||||
|
||||
logpmf
|
||||
pmf
|
||||
cdf
|
||||
ppf
|
||||
|
||||
jax.scipy.stats.beta
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -36,3 +36,26 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
@_wraps(osp_stats.bernoulli.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
return jnp.exp(logpmf(k, p, loc))
|
||||
|
||||
@_wraps(osp_stats.bernoulli.cdf, update_doc=False)
|
||||
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
|
||||
k, p = jnp._promote_args_inexact('bernoulli.cdf', k, p)
|
||||
zero, one = _lax_const(k, 0), _lax_const(k, 1)
|
||||
conds = [
|
||||
jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
|
||||
lax.lt(k, zero),
|
||||
jnp.logical_and(lax.ge(k, zero), lax.lt(k, one)),
|
||||
lax.ge(k, one)
|
||||
]
|
||||
vals = [jnp.nan, zero, one - p, one]
|
||||
return jnp.select(conds, vals)
|
||||
|
||||
@_wraps(osp_stats.bernoulli.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
|
||||
q, p = jnp._promote_args_inexact('bernoulli.ppf', q, p)
|
||||
zero, one = _lax_const(q, 0), _lax_const(q, 1)
|
||||
return jnp.where(
|
||||
jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),
|
||||
jnp.nan,
|
||||
jnp.where(lax.le(q, one - p), zero, one)
|
||||
)
|
||||
|
@ -18,4 +18,6 @@
|
||||
from jax._src.scipy.stats.bernoulli import (
|
||||
logpmf as logpmf,
|
||||
pmf as pmf,
|
||||
cdf as cdf,
|
||||
ppf as ppf
|
||||
)
|
||||
|
@ -155,6 +155,43 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testBernoulliCdf(self, shapes, dtypes):
|
||||
rng_int = jtu.rand_int(self.rng(), -100, 100)
|
||||
rng_uniform = jtu.rand_uniform(self.rng())
|
||||
scipy_fun = osp_stats.bernoulli.cdf
|
||||
lax_fun = lsp_stats.bernoulli.cdf
|
||||
|
||||
def args_maker():
|
||||
x = rng_int(shapes[0], dtypes[0])
|
||||
p = rng_uniform(shapes[1], dtypes[1])
|
||||
return [x, p]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testBernoulliPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.bernoulli.ppf
|
||||
lax_fun = lsp_stats.bernoulli.ppf
|
||||
|
||||
if scipy_version < (1, 9, 2):
|
||||
self.skipTest("Scipy 1.9.2 needed for fix https://github.com/scipy/scipy/pull/17166.")
|
||||
|
||||
def args_maker():
|
||||
q, p = map(rng, shapes, dtypes)
|
||||
q = expit(q)
|
||||
p = expit(p)
|
||||
return [q, p]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testGeomLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user