mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions.
This commit is contained in:
parent
2002d49230
commit
83b3f5b759
@ -172,6 +172,9 @@ jax.scipy.stats.beta
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
|
||||
jax.scipy.stats.betabinom
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -192,6 +195,11 @@ jax.scipy.stats.cauchy
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
isf
|
||||
ppf
|
||||
|
||||
jax.scipy.stats.chi2
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
@ -202,7 +210,9 @@ jax.scipy.stats.chi2
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
|
||||
|
||||
jax.scipy.stats.dirichlet
|
||||
@ -232,6 +242,9 @@ jax.scipy.stats.gamma
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
|
||||
jax.scipy.stats.gennorm
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -296,6 +309,8 @@ jax.scipy.stats.norm
|
||||
logpdf
|
||||
pdf
|
||||
ppf
|
||||
sf
|
||||
isf
|
||||
|
||||
jax.scipy.stats.pareto
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -19,7 +19,7 @@ import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import betaln, xlogy, xlog1py
|
||||
from jax.scipy.special import betaln, betainc, xlogy, xlog1py
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logpdf, update_doc=False)
|
||||
@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
|
||||
return betainc(
|
||||
a,
|
||||
b,
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
_lax_const(x, 1),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, a, b, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
|
@ -18,8 +18,8 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax.numpy import arctan
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@ -31,6 +31,46 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
normalize_term = lax.log(lax.mul(pi, scale))
|
||||
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
|
||||
pi = _lax_const(x, np.pi)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, = promote_args_inexact("cauchy.sf", x)
|
||||
cdf_result = cdf(x, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.isf, update_doc=False)
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
half_pi = _lax_const(q, np.pi / 2)
|
||||
unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
|
||||
return lax.add(lax.mul(unscaled, scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
half_pi = _lax_const(q, np.pi / 2)
|
||||
unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
|
||||
return lax.add(lax.mul(unscaled, scale), loc)
|
||||
|
@ -20,23 +20,52 @@ import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import gammainc
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
two = _lax_const(x, 2)
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
df_on_two = lax.div(df, two)
|
||||
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
|
||||
one = _lax_const(x, 1)
|
||||
two = _lax_const(x, 2)
|
||||
y = lax.div(lax.sub(x, loc), scale)
|
||||
df_on_two = lax.div(df, two)
|
||||
|
||||
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
|
||||
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
|
||||
|
||||
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
|
||||
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
|
||||
|
||||
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
|
||||
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
|
||||
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
|
||||
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
|
||||
|
||||
@_wraps(osp_stats.chi2.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
|
||||
two = _lax_const(scale, 2)
|
||||
return gammainc(
|
||||
lax.div(df, two),
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(
|
||||
lax.sub(x, loc),
|
||||
lax.mul(scale, two),
|
||||
),
|
||||
_lax_const(x, jnp.inf),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, df, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, df, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
|
@ -19,7 +19,7 @@ import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import gammaln, xlogy
|
||||
from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
|
||||
@ -35,3 +35,27 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
|
||||
@_wraps(osp_stats.gamma.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, a, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
|
||||
return gammainc(
|
||||
a,
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
_lax_const(x, jnp.inf),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, a, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
|
||||
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
|
||||
|
@ -23,29 +23,38 @@ from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike) -> Array:
|
||||
x, = promote_args_inexact("logistic.logpdf", x)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
two = _lax_const(x, 2)
|
||||
half_x = lax.div(x, two)
|
||||
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))
|
||||
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x))
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.ppf, update_doc=False)
|
||||
def ppf(x):
|
||||
return logit(x)
|
||||
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
|
||||
return lax.add(lax.mul(logit(x), scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.sf, update_doc=False)
|
||||
def sf(x):
|
||||
return expit(lax.neg(x))
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
|
||||
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.isf, update_doc=False)
|
||||
def isf(x):
|
||||
return -logit(x)
|
||||
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
|
||||
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.cdf, update_doc=False)
|
||||
def cdf(x):
|
||||
return expit(x)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
|
||||
return expit(lax.div(lax.sub(x, loc), scale))
|
||||
|
@ -24,6 +24,7 @@ from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy import special
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
|
||||
@ -54,3 +55,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
@_wraps(osp_stats.norm.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return jnp.asarray(special.ndtri(q) * scale + loc, float)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.isf, update_doc=False)
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
|
||||
|
@ -36,6 +36,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
|
||||
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
|
||||
|
||||
|
||||
@_wraps(osp_stats.t.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
|
@ -18,4 +18,7 @@
|
||||
from jax._src.scipy.stats.beta import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -18,4 +18,9 @@
|
||||
from jax._src.scipy.stats.cauchy import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
isf as isf,
|
||||
ppf as ppf,
|
||||
)
|
||||
|
@ -18,4 +18,7 @@
|
||||
from jax._src.scipy.stats.chi2 import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -18,4 +18,7 @@
|
||||
from jax._src.scipy.stats.gamma import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -21,4 +21,6 @@ from jax._src.scipy.stats.norm import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
ppf as ppf,
|
||||
sf as sf,
|
||||
isf as isf,
|
||||
)
|
||||
|
@ -190,7 +190,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testGeomLogPmf(self, shapes, dtypes):
|
||||
@ -226,6 +226,38 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.beta.logcdf
|
||||
lax_fun = lsp_stats.beta.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.beta.sf
|
||||
lax_fun = lsp_stats.beta.sf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
def testBetaLogPdfZero(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7645
|
||||
a = b = 1.
|
||||
@ -250,6 +282,80 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.logcdf
|
||||
lax_fun = lsp_stats.cauchy.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchySf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.sf
|
||||
lax_fun = lsp_stats.cauchy.sf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyIsf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.isf
|
||||
lax_fun = lsp_stats.cauchy.isf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that q is in desired range
|
||||
# since lax.tan and numpy.tan work different near divergence points
|
||||
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=2e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.ppf
|
||||
lax_fun = lsp_stats.cauchy.ppf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that q is in desired
|
||||
# since lax.tan and numpy.tan work different near divergence points
|
||||
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=2e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@jtu.sample_product(
|
||||
shapes=[
|
||||
[x_shape, alpha_shape]
|
||||
@ -326,6 +432,37 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(
|
||||
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.gamma.logcdf
|
||||
lax_fun = lsp_stats.gamma.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
x = np.clip(x, 0, None).astype(x.dtype)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.gamma.sf
|
||||
lax_fun = lsp_stats.gamma.sf
|
||||
|
||||
def args_maker():
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testGenNormLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -411,32 +548,39 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol={np.float32: 1e-5, np.float64: 1e-6})
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.cdf
|
||||
lax_fun = lsp_stats.logistic.cdf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=3e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticLogpdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.logpdf
|
||||
lax_fun = lsp_stats.logistic.logpdf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testLogisticLogpdfOverflow(self):
|
||||
# Regression test for https://github.com/google/jax/issues/10219
|
||||
@ -445,31 +589,56 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
|
||||
check_dtypes=False)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.ppf
|
||||
lax_fun = lsp_stats.logistic.ppf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.sf
|
||||
lax_fun = lsp_stats.logistic.sf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=2e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=2e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticIsf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.isf
|
||||
lax_fun = lsp_stats.logistic.isf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormLogPdf(self, shapes, dtypes):
|
||||
@ -524,6 +693,22 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.norm.sf
|
||||
lax_fun = lsp_stats.norm.sf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormPpf(self, shapes, dtypes):
|
||||
@ -543,6 +728,24 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormIsf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.norm.isf
|
||||
lax_fun = lsp_stats.norm.isf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure probability is between 0 and 1:
|
||||
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -716,6 +919,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2LogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.chi2.logcdf
|
||||
lax_fun = lsp_stats.chi2.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2Sf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.chi2.sf
|
||||
lax_fun = lsp_stats.chi2.sf
|
||||
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaBinomLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user