mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #17206 from jakevdp:stats-sf
PiperOrigin-RevId: 559241134
This commit is contained in:
commit
c052bbfa68
@ -192,6 +192,7 @@ jax.scipy.stats.beta
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
logsf
|
||||
|
||||
jax.scipy.stats.betabinom
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -225,6 +226,7 @@ jax.scipy.stats.cauchy
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
logsf
|
||||
isf
|
||||
ppf
|
||||
|
||||
@ -240,6 +242,7 @@ jax.scipy.stats.chi2
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
logsf
|
||||
|
||||
|
||||
jax.scipy.stats.dirichlet
|
||||
@ -272,6 +275,7 @@ jax.scipy.stats.gamma
|
||||
cdf
|
||||
logcdf
|
||||
sf
|
||||
logsf
|
||||
|
||||
jax.scipy.stats.gennorm
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -350,12 +354,13 @@ jax.scipy.stats.norm
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
cdf
|
||||
logcdf
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
logcdf
|
||||
ppf
|
||||
sf
|
||||
logsf
|
||||
isf
|
||||
|
||||
jax.scipy.stats.pareto
|
||||
|
@ -59,12 +59,26 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
|
||||
@_wraps(osp_stats.beta.logcdf, update_doc=False)
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(cdf(x, a, b, loc, scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, a, b, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
|
||||
return betainc(
|
||||
b,
|
||||
a,
|
||||
1 - lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(lax.sub(x, loc), scale),
|
||||
_lax_const(x, 1),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.beta.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(sf(x, a, b, loc, scale))
|
||||
|
@ -53,9 +53,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
|
||||
@_wraps(osp_stats.cauchy.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, = promote_args_inexact("cauchy.sf", x)
|
||||
cdf_result = cdf(x, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
|
||||
return cdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
|
||||
return logcdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.cauchy.isf, update_doc=False)
|
||||
|
@ -20,7 +20,7 @@ import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax.scipy.special import gammainc
|
||||
from jax.scipy.special import gammainc, gammaincc
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
|
||||
@ -67,5 +67,21 @@ def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
||||
|
||||
@_wraps(osp_stats.chi2.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, df, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
|
||||
two = _lax_const(scale, 2)
|
||||
return gammaincc(
|
||||
lax.div(df, two),
|
||||
lax.clamp(
|
||||
_lax_const(x, 0),
|
||||
lax.div(
|
||||
lax.sub(x, loc),
|
||||
lax.mul(scale, two),
|
||||
),
|
||||
_lax_const(x, jnp.inf),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@_wraps(osp_stats.chi2.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(sf(x, df, loc, scale))
|
||||
|
@ -59,3 +59,8 @@ def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
|
||||
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
|
||||
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
|
||||
|
||||
|
||||
@_wraps(osp_stats.gamma.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.log(sf(x, a, loc, scale))
|
||||
|
@ -88,18 +88,7 @@ def pdf(x, a, b, loc=0, scale=1):
|
||||
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
|
||||
def logsf(x, a, b, loc=0, scale=1):
|
||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
|
||||
x, a, b = jnp.broadcast_arrays(x, a, b)
|
||||
x = lax.div(lax.sub(x, loc), scale)
|
||||
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
|
||||
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
|
||||
|
||||
logsf = jnp.select(
|
||||
# third condition: avoid catastrophic cancellation (from scipy)
|
||||
[x >= b, x <= a, logsf > -0.1, x > a],
|
||||
[-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf]
|
||||
)
|
||||
logsf = jnp.where(a >= b, jnp.nan, logsf)
|
||||
return logsf
|
||||
return logcdf(-x, -b, -a, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
|
||||
|
@ -16,9 +16,10 @@
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.stats.beta import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
logsf as logsf,
|
||||
pdf as pdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -16,11 +16,12 @@
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.stats.cauchy import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
sf as sf,
|
||||
isf as isf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
logsf as logsf,
|
||||
pdf as pdf,
|
||||
ppf as ppf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -16,9 +16,10 @@
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.stats.chi2 import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
logsf as logsf,
|
||||
pdf as pdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -16,9 +16,10 @@
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.stats.gamma import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
logsf as logsf,
|
||||
pdf as pdf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -257,6 +257,22 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.beta.logsf
|
||||
lax_fun = lsp_stats.beta.logsf
|
||||
|
||||
def args_maker():
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
def testBetaLogPdfZero(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7645
|
||||
a = b = 1.
|
||||
@ -279,7 +295,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker, tol={np.float64: 1E-14})
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyLogCdf(self, shapes, dtypes):
|
||||
@ -299,6 +315,42 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||
atol={np.float64: 1e-14})
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.cdf
|
||||
lax_fun = lsp_stats.cauchy.cdf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||
atol={np.float64: 1e-14})
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.cauchy.logsf
|
||||
lax_fun = lsp_stats.cauchy.logsf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||
atol={np.float64: 1e-14})
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchySf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -314,7 +366,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||
atol={np.float64: 1e-14})
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testCauchyIsf(self, shapes, dtypes):
|
||||
@ -450,6 +503,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.gamma.logsf
|
||||
lax_fun = lsp_stats.gamma.logsf
|
||||
|
||||
def args_maker():
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.gamma.sf
|
||||
lax_fun = lsp_stats.gamma.sf
|
||||
@ -960,6 +1028,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2Cdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.chi2.cdf
|
||||
lax_fun = lsp_stats.chi2.cdf
|
||||
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2Sf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
@ -975,6 +1058,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2LogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.chi2.logsf
|
||||
lax_fun = lsp_stats.chi2.logsf
|
||||
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaBinomLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user