Merge pull request #17206 from jakevdp:stats-sf

PiperOrigin-RevId: 559241134
This commit is contained in:
jax authors 2023-08-22 15:32:23 -07:00
commit c052bbfa68
11 changed files with 172 additions and 36 deletions

View File

@ -192,6 +192,7 @@ jax.scipy.stats.beta
cdf
logcdf
sf
logsf
jax.scipy.stats.betabinom
~~~~~~~~~~~~~~~~~~~~~~~~~
@ -225,6 +226,7 @@ jax.scipy.stats.cauchy
cdf
logcdf
sf
logsf
isf
ppf
@ -240,6 +242,7 @@ jax.scipy.stats.chi2
cdf
logcdf
sf
logsf
jax.scipy.stats.dirichlet
@ -272,6 +275,7 @@ jax.scipy.stats.gamma
cdf
logcdf
sf
logsf
jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~
@ -350,12 +354,13 @@ jax.scipy.stats.norm
.. autosummary::
:toctree: _autosummary
cdf
logcdf
logpdf
pdf
cdf
logcdf
ppf
sf
logsf
isf
jax.scipy.stats.pareto

View File

@ -59,12 +59,26 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
@_wraps(osp_stats.beta.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, b, loc, scale))
@_wraps(osp_stats.beta.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, a, b, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
return betainc(
b,
a,
1 - lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)
@_wraps(osp_stats.beta.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, a, b, loc, scale))

View File

@ -53,9 +53,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
@_wraps(osp_stats.cauchy.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, = promote_args_inexact("cauchy.sf", x)
cdf_result = cdf(x, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
return cdf(-x, -loc, scale)
@_wraps(osp_stats.cauchy.logsf, update_doc=False)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
return logcdf(-x, -loc, scale)
@_wraps(osp_stats.cauchy.isf, update_doc=False)

View File

@ -20,7 +20,7 @@ import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammainc
from jax.scipy.special import gammainc, gammaincc
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
@ -67,5 +67,21 @@ def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
@_wraps(osp_stats.chi2.sf, update_doc=False)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, df, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
two = _lax_const(scale, 2)
return gammaincc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
_lax_const(x, jnp.inf),
),
)
@_wraps(osp_stats.chi2.logsf, update_doc=False)
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, df, loc, scale))

View File

@ -59,3 +59,8 @@ def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
@_wraps(osp_stats.gamma.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, a, loc, scale))

View File

@ -88,18 +88,7 @@ def pdf(x, a, b, loc=0, scale=1):
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
x, a, b = jnp.broadcast_arrays(x, a, b)
x = lax.div(lax.sub(x, loc), scale)
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
logsf = jnp.select(
# third condition: avoid catastrophic cancellation (from scipy)
[x >= b, x <= a, logsf > -0.1, x > a],
[-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf]
)
logsf = jnp.where(a >= b, jnp.nan, logsf)
return logsf
return logcdf(-x, -b, -a, -loc, scale)
@_wraps(osp_stats.truncnorm.sf, update_doc=False)

View File

@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.beta import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf,
)

View File

@ -16,11 +16,12 @@
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.cauchy import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
isf as isf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
ppf as ppf,
sf as sf,
)

View File

@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.chi2 import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf,
)

View File

@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.gamma import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf,
)

View File

@ -257,6 +257,22 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.logsf
lax_fun = lsp_stats.beta.logsf
def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
def testBetaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7645
a = b = 1.
@ -279,7 +295,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker, tol={np.float64: 1E-14})
@genNamedParametersNArgs(3)
def testCauchyLogCdf(self, shapes, dtypes):
@ -299,6 +315,42 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.cdf
lax_fun = lsp_stats.cauchy.cdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyLogSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.logsf
lax_fun = lsp_stats.cauchy.logsf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchySf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
@ -314,7 +366,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyIsf(self, shapes, dtypes):
@ -450,6 +503,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
@genNamedParametersNArgs(4)
def testGammaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logsf
lax_fun = lsp_stats.gamma.logsf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.sf
lax_fun = lsp_stats.gamma.sf
@ -960,6 +1028,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2Cdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.cdf
lax_fun = lsp_stats.chi2.cdf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2Sf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
@ -975,6 +1058,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.logsf
lax_fun = lsp_stats.chi2.logsf
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testBetaBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())