Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions.

This commit is contained in:
Misha 2023-03-12 06:53:09 +01:00
parent 2002d49230
commit 83b3f5b759
14 changed files with 453 additions and 46 deletions

View File

@ -172,6 +172,9 @@ jax.scipy.stats.beta
logpdf
pdf
cdf
logcdf
sf
jax.scipy.stats.betabinom
~~~~~~~~~~~~~~~~~~~~~~~~~
@ -192,6 +195,11 @@ jax.scipy.stats.cauchy
logpdf
pdf
cdf
logcdf
sf
isf
ppf
jax.scipy.stats.chi2
~~~~~~~~~~~~~~~~~~~~
@ -202,7 +210,9 @@ jax.scipy.stats.chi2
logpdf
pdf
cdf
logcdf
sf
jax.scipy.stats.dirichlet
@ -232,6 +242,9 @@ jax.scipy.stats.gamma
logpdf
pdf
cdf
logcdf
sf
jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~
@ -296,6 +309,8 @@ jax.scipy.stats.norm
logpdf
pdf
ppf
sf
isf
jax.scipy.stats.pareto
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -19,7 +19,7 @@ import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import betaln, xlogy, xlog1py
from jax.scipy.special import betaln, betainc, xlogy, xlog1py
@_wraps(osp_stats.beta.logpdf, update_doc=False)
@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, b, loc, scale))
@_wraps(osp_stats.beta.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
return betainc(
a,
b,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)
@_wraps(osp_stats.beta.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, b, loc, scale))
@_wraps(osp_stats.beta.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, a, b, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)

View File

@ -18,8 +18,8 @@ import scipy.stats as osp_stats
from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax.numpy import arctan
from jax._src.typing import Array, ArrayLike
@ -31,6 +31,46 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
normalize_term = lax.log(lax.mul(pi, scale))
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
@_wraps(osp_stats.cauchy.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.cauchy.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
pi = _lax_const(x, np.pi)
scaled_x = lax.div(lax.sub(x, loc), scale)
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
@_wraps(osp_stats.cauchy.logcdf, update_doc=False)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, loc, scale))
@_wraps(osp_stats.cauchy.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, = promote_args_inexact("cauchy.sf", x)
cdf_result = cdf(x, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
@_wraps(osp_stats.cauchy.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
return lax.add(lax.mul(unscaled, scale), loc)
@_wraps(osp_stats.cauchy.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
return lax.add(lax.mul(unscaled, scale), loc)

View File

@ -20,23 +20,52 @@ import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammainc
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
one = _lax_const(x, 1)
two = _lax_const(x, 2)
y = lax.div(lax.sub(x, loc), scale)
df_on_two = lax.div(df, two)
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
one = _lax_const(x, 1)
two = _lax_const(x, 2)
y = lax.div(lax.sub(x, loc), scale)
df_on_two = lax.div(df, two)
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
@_wraps(osp_stats.chi2.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))
return lax.exp(logpdf(x, df, loc, scale))
@_wraps(osp_stats.chi2.cdf, update_doc=False)
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
two = _lax_const(scale, 2)
return gammainc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
_lax_const(x, jnp.inf),
),
)
@_wraps(osp_stats.chi2.logcdf, update_doc=False)
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, df, loc, scale))
@_wraps(osp_stats.chi2.sf, update_doc=False)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, df, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)

View File

@ -19,7 +19,7 @@ import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammaln, xlogy
from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
@ -35,3 +35,27 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
@_wraps(osp_stats.gamma.pdf, update_doc=False)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, loc, scale))
@_wraps(osp_stats.gamma.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
return gammainc(
a,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, jnp.inf),
)
)
@_wraps(osp_stats.gamma.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, loc, scale))
@_wraps(osp_stats.gamma.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale))

View File

@ -23,29 +23,38 @@ from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
def logpdf(x: ArrayLike) -> Array:
x, = promote_args_inexact("logistic.logpdf", x)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
x = lax.div(lax.sub(x, loc), scale)
two = _lax_const(x, 2)
half_x = lax.div(x, two)
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
@_wraps(osp_stats.logistic.pdf, update_doc=False)
def pdf(x: ArrayLike) -> Array:
return lax.exp(logpdf(x))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.logistic.ppf, update_doc=False)
def ppf(x):
return logit(x)
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
return lax.add(lax.mul(logit(x), scale), loc)
@_wraps(osp_stats.logistic.sf, update_doc=False)
def sf(x):
return expit(lax.neg(x))
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
@_wraps(osp_stats.logistic.isf, update_doc=False)
def isf(x):
return -logit(x)
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
@_wraps(osp_stats.logistic.cdf, update_doc=False)
def cdf(x):
return expit(x)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
return expit(lax.div(lax.sub(x, loc), scale))

View File

@ -24,6 +24,7 @@ from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy import special
@_wraps(osp_stats.norm.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
@ -54,3 +55,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
@_wraps(osp_stats.norm.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return jnp.asarray(special.ndtri(q) * scale + loc, float)
@_wraps(osp_stats.norm.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
@_wraps(osp_stats.norm.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)

View File

@ -36,6 +36,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
@_wraps(osp_stats.t.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))

View File

@ -18,4 +18,7 @@
from jax._src.scipy.stats.beta import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)

View File

@ -18,4 +18,9 @@
from jax._src.scipy.stats.cauchy import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
isf as isf,
ppf as ppf,
)

View File

@ -18,4 +18,7 @@
from jax._src.scipy.stats.chi2 import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)

View File

@ -18,4 +18,7 @@
from jax._src.scipy.stats.gamma import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)

View File

@ -21,4 +21,6 @@ from jax._src.scipy.stats.norm import (
logpdf as logpdf,
pdf as pdf,
ppf as ppf,
sf as sf,
isf as isf,
)

View File

@ -190,7 +190,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testGeomLogPmf(self, shapes, dtypes):
@ -226,6 +226,38 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.logcdf
lax_fun = lsp_stats.beta.logcdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.sf
lax_fun = lsp_stats.beta.sf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
def testBetaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7645
a = b = 1.
@ -250,6 +282,80 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testCauchyLogCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.logcdf
lax_fun = lsp_stats.cauchy.logcdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testCauchySf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.sf
lax_fun = lsp_stats.cauchy.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testCauchyIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.isf
lax_fun = lsp_stats.cauchy.isf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that q is in desired range
# since lax.tan and numpy.tan work different near divergence points
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testCauchyPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.ppf
lax_fun = lsp_stats.cauchy.ppf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that q is in desired
# since lax.tan and numpy.tan work different near divergence points
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@jtu.sample_product(
shapes=[
[x_shape, alpha_shape]
@ -326,6 +432,37 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self.assertAllClose(
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
@genNamedParametersNArgs(4)
def testGammaLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logcdf
lax_fun = lsp_stats.gamma.logcdf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
x = np.clip(x, 0, None).astype(x.dtype)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.sf
lax_fun = lsp_stats.gamma.sf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testGenNormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
@ -411,32 +548,39 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol={np.float32: 1e-5, np.float64: 1e-6})
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.cdf
lax_fun = lsp_stats.logistic.cdf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=3e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticLogpdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.logpdf
lax_fun = lsp_stats.logistic.logpdf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
def testLogisticLogpdfOverflow(self):
# Regression test for https://github.com/google/jax/issues/10219
@ -445,31 +589,56 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
check_dtypes=False)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.ppf
lax_fun = lsp_stats.logistic.ppf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.sf
lax_fun = lsp_stats.logistic.sf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-5)
self._CompileAndCheck(lax_fun, args_maker)
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testLogisticIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.isf
lax_fun = lsp_stats.logistic.isf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testNormLogPdf(self, shapes, dtypes):
@ -524,6 +693,22 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.sf
lax_fun = lsp_stats.norm.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormPpf(self, shapes, dtypes):
@ -543,6 +728,24 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testNormIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.isf
lax_fun = lsp_stats.norm.isf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# ensure probability is between 0 and 1:
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(5)
def testTruncnormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
@ -716,6 +919,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.logcdf
lax_fun = lsp_stats.chi2.logcdf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2Sf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.sf
lax_fun = lsp_stats.chi2.sf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testBetaBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())