mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic.
This commit is contained in:
parent
2ccd785e16
commit
feb9ab33af
@ -172,6 +172,9 @@ jax.scipy.stats.beta
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
|
||||
jax.scipy.stats.betabinom
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -192,6 +195,11 @@ jax.scipy.stats.cauchy
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
isf
|
||||
ppf
|
||||
|
||||
jax.scipy.stats.chi2
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
@ -202,7 +210,9 @@ jax.scipy.stats.chi2
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
|
||||
|
||||
jax.scipy.stats.dirichlet
|
||||
@ -232,6 +242,9 @@ jax.scipy.stats.gamma
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
|
||||
jax.scipy.stats.gennorm
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -296,6 +309,8 @@ jax.scipy.stats.norm
|
||||
logpdf
|
||||
pdf
|
||||
ppf
|
||||
sf
|
||||
isf
|
||||
|
||||
jax.scipy.stats.pareto
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import betaln, xlogy, xlog1py
|
||||
from jax.scipy.special import betaln, betainc, xlogy, xlog1py
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logpdf, update_doc=False)
|
||||
@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, b, loc, scale = _promote_args_inexact("beta.cdf", x, a, b, loc, scale)
|
||||
return betainc(
|
||||
a,
|
||||
b,
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
_lax_const(x, 1),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, a, b, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
|
@ -20,6 +20,7 @@ from jax import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact
|
||||
from jax._src.numpy.lax_numpy import arctan
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@ -31,6 +32,44 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
normalize_term = lax.log(lax.mul(pi, scale))
|
||||
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = _promote_args_inexact("cauchy.cdf", x, loc, scale)
|
||||
pi = _lax_const(x, np.pi)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, = _promote_args_inexact("cauchy.sf", x)
|
||||
return lax.sub(_lax_const(x, 1), cdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.isf, update_doc=False)
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = _promote_args_inexact("cauchy.isf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
half_pi = _lax_const(q, np.pi / 2)
|
||||
unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
|
||||
return lax.add(lax.mul(unscaled, scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = _promote_args_inexact("cauchy.ppf", q, loc, scale)
|
||||
pi = _lax_const(q, np.pi)
|
||||
half_pi = _lax_const(q, np.pi / 2)
|
||||
unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
|
||||
return lax.add(lax.mul(unscaled, scale), loc)
|
||||
|
@ -20,6 +20,7 @@ from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import gammainc
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
|
||||
@ -37,6 +38,35 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
|
||||
return where(lax.lt(x, loc), -inf, log_probs)
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, df, loc, scale = _promote_args_inexact("chi2.cdf", x, df, loc, scale)
|
||||
two = _lax_const(scale, 2)
|
||||
return gammainc(
|
||||
lax.div(df, two),
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(
|
||||
lax.sub(x, loc),
|
||||
lax.mul(scale, two),
|
||||
),
|
||||
inf,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, df, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, df, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
|
@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import gammaln, xlogy
|
||||
from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
|
||||
@ -32,6 +32,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
|
||||
log_probs = lax.sub(log_linear_term, shape_terms)
|
||||
return where(lax.lt(x, loc), -inf, log_probs)
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, a, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = _promote_args_inexact("gamma.cdf", x, a, loc, scale)
|
||||
return gammainc(
|
||||
a,
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
inf,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, a, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = _promote_args_inexact("gamma.sf", x, a, loc, scale)
|
||||
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
|
||||
|
@ -24,29 +24,38 @@ from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike) -> Array:
|
||||
x, = _promote_args_inexact("logistic.logpdf", x)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = _promote_args_inexact("logistic.logpdf", x, loc, scale)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
two = _lax_const(x, 2)
|
||||
half_x = lax.div(x, two)
|
||||
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))
|
||||
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x))
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.ppf, update_doc=False)
|
||||
def ppf(x):
|
||||
return logit(x)
|
||||
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = _promote_args_inexact("logistic.ppf", x, loc, scale)
|
||||
return lax.add(lax.mul(logit(x), scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.sf, update_doc=False)
|
||||
def sf(x):
|
||||
return expit(lax.neg(x))
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = _promote_args_inexact("logistic.sf", x, loc, scale)
|
||||
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.isf, update_doc=False)
|
||||
def isf(x):
|
||||
return -logit(x)
|
||||
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = _promote_args_inexact("logistic.isf", x, loc, scale)
|
||||
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
|
||||
|
||||
|
||||
@_wraps(osp_stats.logistic.cdf, update_doc=False)
|
||||
def cdf(x):
|
||||
return expit(x)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = _promote_args_inexact("logistic.cdf", x, loc, scale)
|
||||
return expit(lax.div(lax.sub(x, loc), scale))
|
||||
|
@ -25,6 +25,7 @@ from jax._src.numpy.lax_numpy import _promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy import special
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
|
||||
@ -55,3 +56,13 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
@_wraps(osp_stats.norm.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return jnp.asarray(special.ndtri(q) * scale + loc, float)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.sub(_lax_const(x, 1), cdf(x, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.isf, update_doc=False)
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
|
||||
|
@ -37,6 +37,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
|
||||
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
|
||||
|
||||
|
||||
@_wraps(osp_stats.t.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, df, loc, scale))
|
||||
|
@ -18,4 +18,7 @@
|
||||
from jax._src.scipy.stats.beta import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -18,4 +18,9 @@
|
||||
from jax._src.scipy.stats.cauchy import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
isf as isf,
|
||||
ppf as ppf,
|
||||
)
|
||||
|
@ -18,4 +18,7 @@
|
||||
from jax._src.scipy.stats.chi2 import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -18,4 +18,7 @@
|
||||
from jax._src.scipy.stats.gamma import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -21,4 +21,6 @@ from jax._src.scipy.stats.norm import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
ppf as ppf,
|
||||
sf as sf,
|
||||
isf as isf,
|
||||
)
|
||||
|
@ -226,6 +226,38 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.beta.logcdf
|
||||
lax_fun = lsp_stats.beta.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.beta.sf
|
||||
lax_fun = lsp_stats.beta.sf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
def testBetaLogPdfZero(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7645
|
||||
a = b = 1.
|
||||
@ -250,6 +282,80 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.logcdf
|
||||
lax_fun = lsp_stats.cauchy.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchySf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.sf
|
||||
lax_fun = lsp_stats.cauchy.sf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyIsf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.isf
|
||||
lax_fun = lsp_stats.cauchy.isf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that q is in desired range
|
||||
# since lax.tan and numpy.tan work different near divergence points
|
||||
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=2e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.ppf
|
||||
lax_fun = lsp_stats.cauchy.ppf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that q is in desired
|
||||
# since lax.tan and numpy.tan work different near divergence points
|
||||
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=2e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shapes=[
|
||||
[x_shape, alpha_shape]
|
||||
@ -321,6 +427,37 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.gamma.logcdf
|
||||
lax_fun = lsp_stats.gamma.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
x = np.clip(x, 0, None).astype(x.dtype)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.gamma.sf
|
||||
lax_fun = lsp_stats.gamma.sf
|
||||
|
||||
def args_maker():
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testGammaLogPdfZero(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7256
|
||||
self.assertAllClose(
|
||||
@ -411,28 +548,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol={np.float32: 1e-5, np.float64: 1e-6})
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.cdf
|
||||
lax_fun = lsp_stats.logistic.cdf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=3e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticLogpdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.logpdf
|
||||
lax_fun = lsp_stats.logistic.logpdf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
@ -445,32 +588,54 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
|
||||
check_dtypes=False)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.ppf
|
||||
lax_fun = lsp_stats.logistic.ppf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.sf
|
||||
lax_fun = lsp_stats.logistic.sf
|
||||
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=2e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLogisticIsf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.logistic.isf
|
||||
lax_fun = lsp_stats.logistic.isf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -525,6 +690,24 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.norm.sf
|
||||
lax_fun = lsp_stats.norm.sf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -543,6 +726,25 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormIsf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.norm.isf
|
||||
lax_fun = lsp_stats.norm.isf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure probability is between 0 and 1:
|
||||
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testTruncnormLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -716,6 +918,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2LogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.chi2.logcdf
|
||||
lax_fun = lsp_stats.chi2.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2Sf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.chi2.sf
|
||||
lax_fun = lsp_stats.chi2.sf
|
||||
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaBinomLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user