jax.scipy.stats: add logsf & make sf more accurate near zero

This commit is contained in:
Jake VanderPlas 2023-08-22 14:44:32 -07:00
parent 6670ea46d9
commit d1c2277bfc
11 changed files with 172 additions and 36 deletions

View File

@ -192,6 +192,7 @@ jax.scipy.stats.beta
cdf cdf
logcdf logcdf
sf sf
logsf
jax.scipy.stats.betabinom jax.scipy.stats.betabinom
~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~
@ -225,6 +226,7 @@ jax.scipy.stats.cauchy
cdf cdf
logcdf logcdf
sf sf
logsf
isf isf
ppf ppf
@ -240,6 +242,7 @@ jax.scipy.stats.chi2
cdf cdf
logcdf logcdf
sf sf
logsf
jax.scipy.stats.dirichlet jax.scipy.stats.dirichlet
@ -272,6 +275,7 @@ jax.scipy.stats.gamma
cdf cdf
logcdf logcdf
sf sf
logsf
jax.scipy.stats.gennorm jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~
@ -350,12 +354,13 @@ jax.scipy.stats.norm
.. autosummary:: .. autosummary::
:toctree: _autosummary :toctree: _autosummary
cdf
logcdf
logpdf logpdf
pdf pdf
cdf
logcdf
ppf ppf
sf sf
logsf
isf isf
jax.scipy.stats.pareto jax.scipy.stats.pareto

View File

@ -59,12 +59,26 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
# beta.logcdf (unchanged by this commit, so every line below shows the same
# text twice: old column, then new column). It returns lax.log applied to
# cdf(x, a, b, loc, scale).
@_wraps(osp_stats.beta.logcdf, update_doc=False) @_wraps(osp_stats.beta.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, b, loc, scale)) return lax.log(cdf(x, a, b, loc, scale))
# beta.sf (changed by this commit; removed text on the left, added text on the
# right of each fused line). The removed version computed 1 - cdf(...) via
# lax.sub. The added version promotes all five arguments and returns
# betainc(b, a, 1 - y), where y = (x - loc) / scale is passed through
# lax.clamp with bounds 0 and 1.
@_wraps(osp_stats.beta.sf, update_doc=False) @_wraps(osp_stats.beta.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, a, b, loc, scale) x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result) return betainc(
b,
a,
1 - lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)
@_wraps(osp_stats.beta.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
          loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Log of the survival function: evaluate sf once, then take its log.
  survival = sf(x, a, b, loc, scale)
  return lax.log(survival)

View File

@ -53,9 +53,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
# cauchy.sf (changed by this commit; removed text on the left, added text on
# the right of each fused line). The removed version promoted only x and
# returned 1 - cdf(x, loc, scale) via lax.sub. The added version promotes
# x, loc and scale together and returns cdf(-x, -loc, scale).
@_wraps(osp_stats.cauchy.sf, update_doc=False) @_wraps(osp_stats.cauchy.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, = promote_args_inexact("cauchy.sf", x) x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
cdf_result = cdf(x, loc, scale) return cdf(-x, -loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
@_wraps(osp_stats.cauchy.logsf, update_doc=False)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Evaluate the log-CDF at the negated point and negated location,
  # keeping the scale as given.
  x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
  mirrored_x = -x
  mirrored_loc = -loc
  return logcdf(mirrored_x, mirrored_loc, scale)
@_wraps(osp_stats.cauchy.isf, update_doc=False) @_wraps(osp_stats.cauchy.isf, update_doc=False)

View File

@ -20,7 +20,7 @@ import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammainc from jax.scipy.special import gammainc, gammaincc
@_wraps(osp_stats.chi2.logpdf, update_doc=False) @_wraps(osp_stats.chi2.logpdf, update_doc=False)
@ -67,5 +67,21 @@ def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
# chi2.sf (changed by this commit; removed text on the left, added text on the
# right of each fused line). The removed version computed 1 - cdf(...) via
# lax.sub. The added version promotes all four arguments and returns
# gammaincc(df / 2, y), where y = (x - loc) / (scale * 2) is passed through
# lax.clamp with bounds 0 and jnp.inf.
@_wraps(osp_stats.chi2.sf, update_doc=False) @_wraps(osp_stats.chi2.sf, update_doc=False)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, df, loc, scale) x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result) two = _lax_const(scale, 2)
return gammaincc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
_lax_const(x, jnp.inf),
),
)
@_wraps(osp_stats.chi2.logsf, update_doc=False)
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Log of the survival function: evaluate sf once, then take its log.
  survival = sf(x, df, loc, scale)
  return lax.log(survival)

View File

@ -59,3 +59,8 @@ def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale)) return gammaincc(a, lax.div(lax.sub(x, loc), scale))
@_wraps(osp_stats.gamma.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Log of the survival function: evaluate sf once, then take its log.
  survival = sf(x, a, loc, scale)
  return lax.log(survival)

View File

@ -88,18 +88,7 @@ def pdf(x, a, b, loc=0, scale=1):
# truncnorm.logsf (changed by this commit; removed text on the left, added
# text on the right of each fused line; lines with a single statement are
# removed-only). The removed version standardized x, combined
# _log_gauss_mass terms, and used jnp.select / jnp.where to handle the
# boundary cases and a >= b. The added version keeps the argument promotion
# and returns logcdf(-x, -b, -a, -loc, scale).
@_wraps(osp_stats.truncnorm.logsf, update_doc=False) @_wraps(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1): def logsf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale) x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
x, a, b = jnp.broadcast_arrays(x, a, b) return logcdf(-x, -b, -a, -loc, scale)
x = lax.div(lax.sub(x, loc), scale)
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
logsf = jnp.select(
# third condition: avoid catastrophic cancellation (from scipy)
[x >= b, x <= a, logsf > -0.1, x > a],
[-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf]
)
logsf = jnp.where(a >= b, jnp.nan, logsf)
return logsf
@_wraps(osp_stats.truncnorm.sf, update_doc=False) @_wraps(osp_stats.truncnorm.sf, update_doc=False)

View File

@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570 # See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.beta import ( from jax._src.scipy.stats.beta import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf, cdf as cdf,
logcdf as logcdf, logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf, sf as sf,
) )

View File

@ -16,11 +16,12 @@
# See PEP 484 & https://github.com/google/jax/issues/7570 # See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.cauchy import ( from jax._src.scipy.stats.cauchy import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf, cdf as cdf,
logcdf as logcdf,
sf as sf,
isf as isf, isf as isf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
ppf as ppf, ppf as ppf,
sf as sf,
) )

View File

@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570 # See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.chi2 import ( from jax._src.scipy.stats.chi2 import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf, cdf as cdf,
logcdf as logcdf, logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf, sf as sf,
) )

View File

@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570 # See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.gamma import ( from jax._src.scipy.stats.gamma import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf, cdf as cdf,
logcdf as logcdf, logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf, sf as sf,
) )

View File

@ -257,6 +257,22 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker, self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4}) rtol={np.float32: 2e-3, np.float64: 1e-4})
@genNamedParametersNArgs(5)
def testBetaLogSf(self, shapes, dtypes):
  # Compare lsp_stats.beta.logsf with osp_stats.beta.logsf on inputs drawn
  # from jtu.rand_positive, then run the compile check.
  sample = jtu.rand_positive(self.rng())
  reference_fn = osp_stats.beta.logsf
  jax_fn = lsp_stats.beta.logsf

  def make_args():
    x, a, b, loc, scale = (sample(shape, dtype)
                           for shape, dtype in zip(shapes, dtypes))
    return [x, a, b, loc, scale]

  compile_rtol = {np.float32: 2e-3, np.float64: 1e-4}
  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(reference_fn, jax_fn, make_args,
                            check_dtypes=False, tol=1e-3)
    self._CompileAndCheck(jax_fn, make_args, rtol=compile_rtol)
def testBetaLogPdfZero(self): def testBetaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7645 # Regression test for https://github.com/google/jax/issues/7645
a = b = 1. a = b = 1.
@ -279,7 +295,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match(dtypes): with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4) tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker, tol={np.float64: 1E-14})
@genNamedParametersNArgs(3) @genNamedParametersNArgs(3)
def testCauchyLogCdf(self, shapes, dtypes): def testCauchyLogCdf(self, shapes, dtypes):
@ -299,6 +315,42 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14}) atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyCdf(self, shapes, dtypes):
  # Compare lsp_stats.cauchy.cdf with osp_stats.cauchy.cdf on inputs drawn
  # from jtu.rand_default, then run the compile check.
  sample = jtu.rand_default(self.rng())
  reference_fn = osp_stats.cauchy.cdf
  jax_fn = lsp_stats.cauchy.cdf

  def make_args():
    x, loc, raw_scale = (sample(shape, dtype)
                         for shape, dtype in zip(shapes, dtypes))
    # clipping to ensure that scale is not too low
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    scale = clipped.astype(raw_scale.dtype)
    return [x, loc, scale]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(reference_fn, jax_fn, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(jax_fn, make_args,
                          rtol={np.float64: 1e-14},
                          atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testCauchyLogSf(self, shapes, dtypes):
  # Compare lsp_stats.cauchy.logsf with osp_stats.cauchy.logsf on inputs
  # drawn from jtu.rand_default, then run the compile check.
  sample = jtu.rand_default(self.rng())
  reference_fn = osp_stats.cauchy.logsf
  jax_fn = lsp_stats.cauchy.logsf

  def make_args():
    x, loc, raw_scale = (sample(shape, dtype)
                         for shape, dtype in zip(shapes, dtypes))
    # clipping to ensure that scale is not too low
    clipped = np.clip(np.abs(raw_scale), a_min=0.1, a_max=None)
    scale = clipped.astype(raw_scale.dtype)
    return [x, loc, scale]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(reference_fn, jax_fn, make_args,
                            check_dtypes=False, tol=1e-4)
    self._CompileAndCheck(jax_fn, make_args,
                          rtol={np.float64: 1e-14},
                          atol={np.float64: 1e-14})
@genNamedParametersNArgs(3) @genNamedParametersNArgs(3)
def testCauchySf(self, shapes, dtypes): def testCauchySf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
@ -314,7 +366,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match(dtypes): with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4) tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})
@genNamedParametersNArgs(3) @genNamedParametersNArgs(3)
def testCauchyIsf(self, shapes, dtypes): def testCauchyIsf(self, shapes, dtypes):
@ -450,6 +503,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
# testGammaLogSf: the decorator and def lines are unchanged context (shown
# twice: old column, then new column); the body below them is added text.
# The body checks lsp_stats.gamma.logsf against osp_stats.gamma.logsf on
# inputs drawn from jtu.rand_positive with tol=5e-4, then runs
# _CompileAndCheck with default tolerances.
@genNamedParametersNArgs(4) @genNamedParametersNArgs(4)
def testGammaLogSf(self, shapes, dtypes): def testGammaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logsf
lax_fun = lsp_stats.gamma.logsf
def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng()) rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.sf scipy_fun = osp_stats.gamma.sf
lax_fun = lsp_stats.gamma.sf lax_fun = lsp_stats.gamma.sf
@ -960,6 +1028,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4) tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2Cdf(self, shapes, dtypes):
  # Compare lsp_stats.chi2.cdf with osp_stats.chi2.cdf on inputs drawn from
  # jtu.rand_positive, then run the compile check.
  sample = jtu.rand_positive(self.rng())
  reference_fn = osp_stats.chi2.cdf
  jax_fn = lsp_stats.chi2.cdf

  def make_args():
    x, df, loc, scale = (sample(shape, dtype)
                         for shape, dtype in zip(shapes, dtypes))
    return [x, df, loc, scale]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(reference_fn, jax_fn, make_args,
                            check_dtypes=False, tol=5e-4)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(4) @genNamedParametersNArgs(4)
def testChi2Sf(self, shapes, dtypes): def testChi2Sf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng()) rng = jtu.rand_positive(self.rng())
@ -975,6 +1058,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=5e-4) tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogSf(self, shapes, dtypes):
  # Compare lsp_stats.chi2.logsf with osp_stats.chi2.logsf on inputs drawn
  # from jtu.rand_positive, then run the compile check.
  sample = jtu.rand_positive(self.rng())
  reference_fn = osp_stats.chi2.logsf
  jax_fn = lsp_stats.chi2.logsf

  def make_args():
    x, df, loc, scale = (sample(shape, dtype)
                         for shape, dtype in zip(shapes, dtypes))
    return [x, df, loc, scale]

  with jtu.strict_promotion_if_dtypes_match(dtypes):
    self._CheckAgainstNumpy(reference_fn, jax_fn, make_args,
                            check_dtypes=False, tol=5e-4)
    self._CompileAndCheck(jax_fn, make_args)
@genNamedParametersNArgs(5) @genNamedParametersNArgs(5)
def testBetaBinomLogPmf(self, shapes, dtypes): def testBetaBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng()) rng = jtu.rand_positive(self.rng())