Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic.

This commit is contained in:
Misha 2023-03-06 00:11:40 +01:00
parent 2ccd785e16
commit feb9ab33af
14 changed files with 440 additions and 34 deletions

View File

@ -172,6 +172,9 @@ jax.scipy.stats.beta
logpdf
pdf
cdf
logcdf
sf
jax.scipy.stats.betabinom
~~~~~~~~~~~~~~~~~~~~~~~~~
@ -192,6 +195,11 @@ jax.scipy.stats.cauchy
logpdf
pdf
cdf
logcdf
sf
isf
ppf
jax.scipy.stats.chi2
~~~~~~~~~~~~~~~~~~~~
@ -202,7 +210,9 @@ jax.scipy.stats.chi2
logpdf
pdf
cdf
logcdf
sf
jax.scipy.stats.dirichlet
@ -232,6 +242,9 @@ jax.scipy.stats.gamma
logpdf
pdf
cdf
logcdf
sf
jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~
@ -296,6 +309,8 @@ jax.scipy.stats.norm
logpdf
pdf
ppf
sf
isf
jax.scipy.stats.pareto
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import betaln, xlogy, xlog1py
from jax.scipy.special import betaln, betainc, xlogy, xlog1py
@_wraps(osp_stats.beta.logpdf, update_doc=False)
@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, b, loc, scale))
@_wraps(osp_stats.beta.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = _promote_args_inexact("beta.cdf", x, a, b, loc, scale)
return betainc(
a,
b,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)
@_wraps(osp_stats.beta.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, b, loc, scale))
@_wraps(osp_stats.beta.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, a, b, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)

View File

@ -20,6 +20,7 @@ from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy.lax_numpy import arctan
from jax._src.typing import Array, ArrayLike
@ -31,6 +32,44 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
normalize_term = lax.log(lax.mul(pi, scale))
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
@_wraps(osp_stats.cauchy.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.cauchy.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("cauchy.cdf", x, loc, scale)
pi = _lax_const(x, np.pi)
scaled_x = lax.div(lax.sub(x, loc), scale)
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))
@_wraps(osp_stats.cauchy.logcdf, update_doc=False)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, loc, scale))
@_wraps(osp_stats.cauchy.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, = _promote_args_inexact("cauchy.sf", x)
return lax.sub(_lax_const(x, 1), cdf(x, loc, scale))
@_wraps(osp_stats.cauchy.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = _promote_args_inexact("cauchy.isf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
return lax.add(lax.mul(unscaled, scale), loc)
@_wraps(osp_stats.cauchy.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = _promote_args_inexact("cauchy.ppf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
return lax.add(lax.mul(unscaled, scale), loc)

View File

@ -20,6 +20,7 @@ from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammainc
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
@ -37,6 +38,35 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return where(lax.lt(x, loc), -inf, log_probs)
@_wraps(osp_stats.chi2.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))
@_wraps(osp_stats.chi2.cdf, update_doc=False)
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = _promote_args_inexact("chi2.cdf", x, df, loc, scale)
two = _lax_const(scale, 2)
return gammainc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
inf,
),
)
@_wraps(osp_stats.chi2.logcdf, update_doc=False)
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, df, loc, scale))
@_wraps(osp_stats.chi2.sf, update_doc=False)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, df, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)

View File

@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammaln, xlogy
from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc
@_wraps(osp_stats.gamma.logpdf, update_doc=False)
@ -32,6 +32,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
log_probs = lax.sub(log_linear_term, shape_terms)
return where(lax.lt(x, loc), -inf, log_probs)
@_wraps(osp_stats.gamma.pdf, update_doc=False)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, loc, scale))
@_wraps(osp_stats.gamma.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = _promote_args_inexact("gamma.cdf", x, a, loc, scale)
return gammainc(
a,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
inf,
)
)
@_wraps(osp_stats.gamma.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, loc, scale))
@_wraps(osp_stats.gamma.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = _promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale))

View File

@ -24,29 +24,38 @@ from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.logistic.logpdf, update_doc=False)
def logpdf(x: ArrayLike) -> Array:
x, = _promote_args_inexact("logistic.logpdf", x)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("logistic.logpdf", x, loc, scale)
x = lax.div(lax.sub(x, loc), scale)
two = _lax_const(x, 2)
half_x = lax.div(x, two)
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
@_wraps(osp_stats.logistic.pdf, update_doc=False)
def pdf(x: ArrayLike) -> Array:
return lax.exp(logpdf(x))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))
@_wraps(osp_stats.logistic.ppf, update_doc=False)
def ppf(x):
return logit(x)
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("logistic.ppf", x, loc, scale)
return lax.add(lax.mul(logit(x), scale), loc)
@_wraps(osp_stats.logistic.sf, update_doc=False)
def sf(x):
return expit(lax.neg(x))
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("logistic.sf", x, loc, scale)
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))
@_wraps(osp_stats.logistic.isf, update_doc=False)
def isf(x):
return -logit(x)
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("logistic.isf", x, loc, scale)
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)
@_wraps(osp_stats.logistic.cdf, update_doc=False)
def cdf(x):
return expit(x)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("logistic.cdf", x, loc, scale)
return expit(lax.div(lax.sub(x, loc), scale))

View File

@ -25,6 +25,7 @@ from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy import special
@_wraps(osp_stats.norm.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
@ -55,3 +56,13 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
@_wraps(osp_stats.norm.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return jnp.asarray(special.ndtri(q) * scale + loc, float)
@_wraps(osp_stats.norm.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.sub(_lax_const(x, 1), cdf(x, loc, scale))
@_wraps(osp_stats.norm.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)

View File

@ -37,6 +37,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
@_wraps(osp_stats.t.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))

View File

@ -18,4 +18,7 @@
from jax._src.scipy.stats.beta import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)

View File

@ -18,4 +18,9 @@
from jax._src.scipy.stats.cauchy import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
isf as isf,
ppf as ppf,
)

View File

@ -18,4 +18,7 @@
from jax._src.scipy.stats.chi2 import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)

View File

@ -18,4 +18,7 @@
from jax._src.scipy.stats.gamma import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)

View File

@ -21,4 +21,6 @@ from jax._src.scipy.stats.norm import (
logpdf as logpdf,
pdf as pdf,
ppf as ppf,
sf as sf,
isf as isf,
)

View File

@ -226,6 +226,38 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.logcdf
lax_fun = lsp_stats.beta.logcdf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.sf
lax_fun = lsp_stats.beta.sf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
def testBetaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7645
a = b = 1.
@ -250,6 +282,80 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testCauchyLogCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.logcdf
lax_fun = lsp_stats.cauchy.logcdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testCauchySf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.sf
lax_fun = lsp_stats.cauchy.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testCauchyIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.isf
lax_fun = lsp_stats.cauchy.isf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that q is in desired range
# since lax.tan and numpy.tan work different near divergence points
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testCauchyPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.ppf
lax_fun = lsp_stats.cauchy.ppf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that q is in desired
# since lax.tan and numpy.tan work different near divergence points
q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-4)
self._CompileAndCheck(lax_fun, args_maker)
@jtu.sample_product(
shapes=[
[x_shape, alpha_shape]
@ -321,6 +427,37 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logcdf
lax_fun = lsp_stats.gamma.logcdf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
x = np.clip(x, 0, None).astype(x.dtype)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.sf
lax_fun = lsp_stats.gamma.sf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
def testGammaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7256
self.assertAllClose(
@ -411,28 +548,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol={np.float32: 1e-5, np.float64: 1e-6})
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.cdf
lax_fun = lsp_stats.logistic.cdf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=3e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticLogpdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.logpdf
lax_fun = lsp_stats.logistic.logpdf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
@ -445,32 +588,54 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
check_dtypes=False)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.ppf
lax_fun = lsp_stats.logistic.ppf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
@genNamedParametersNArgs(3)
def testLogisticSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.sf
lax_fun = lsp_stats.logistic.sf
def args_maker():
return list(map(rng, shapes, dtypes))
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=2e-5)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testLogisticIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.logistic.isf
lax_fun = lsp_stats.logistic.isf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
@ -525,6 +690,24 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.sf
lax_fun = lsp_stats.norm.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
@ -543,6 +726,25 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(3)
def testNormIsf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.isf
lax_fun = lsp_stats.norm.isf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# ensure probability is between 0 and 1:
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(5)
def testTruncnormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
@ -716,6 +918,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.logcdf
lax_fun = lsp_stats.chi2.logcdf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2Sf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.sf
lax_fun = lsp_stats.chi2.sf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testBetaBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())