mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jax.scipy.stats: add logsf & make sf more accurate near zero
This commit is contained in:
parent
6670ea46d9
commit
d1c2277bfc
@ -192,6 +192,7 @@ jax.scipy.stats.beta
|
|||||||
cdf
|
cdf
|
||||||
logcdf
|
logcdf
|
||||||
sf
|
sf
|
||||||
|
logsf
|
||||||
|
|
||||||
jax.scipy.stats.betabinom
|
jax.scipy.stats.betabinom
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
@ -225,6 +226,7 @@ jax.scipy.stats.cauchy
|
|||||||
cdf
|
cdf
|
||||||
logcdf
|
logcdf
|
||||||
sf
|
sf
|
||||||
|
logsf
|
||||||
isf
|
isf
|
||||||
ppf
|
ppf
|
||||||
|
|
||||||
@ -240,6 +242,7 @@ jax.scipy.stats.chi2
|
|||||||
cdf
|
cdf
|
||||||
logcdf
|
logcdf
|
||||||
sf
|
sf
|
||||||
|
logsf
|
||||||
|
|
||||||
|
|
||||||
jax.scipy.stats.dirichlet
|
jax.scipy.stats.dirichlet
|
||||||
@ -272,6 +275,7 @@ jax.scipy.stats.gamma
|
|||||||
cdf
|
cdf
|
||||||
logcdf
|
logcdf
|
||||||
sf
|
sf
|
||||||
|
logsf
|
||||||
|
|
||||||
jax.scipy.stats.gennorm
|
jax.scipy.stats.gennorm
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
@ -350,12 +354,13 @@ jax.scipy.stats.norm
|
|||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
cdf
|
|
||||||
logcdf
|
|
||||||
logpdf
|
logpdf
|
||||||
pdf
|
pdf
|
||||||
|
cdf
|
||||||
|
logcdf
|
||||||
ppf
|
ppf
|
||||||
sf
|
sf
|
||||||
|
logsf
|
||||||
isf
|
isf
|
||||||
|
|
||||||
jax.scipy.stats.pareto
|
jax.scipy.stats.pareto
|
||||||
|
@ -59,12 +59,26 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
|||||||
|
|
||||||
@_wraps(osp_stats.beta.logcdf, update_doc=False)
|
@_wraps(osp_stats.beta.logcdf, update_doc=False)
|
||||||
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
return lax.log(cdf(x, a, b, loc, scale))
|
return lax.log(cdf(x, a, b, loc, scale))
|
||||||
|
|
||||||
|
|
||||||
@_wraps(osp_stats.beta.sf, update_doc=False)
|
@_wraps(osp_stats.beta.sf, update_doc=False)
|
||||||
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||||
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
cdf_result = cdf(x, a, b, loc, scale)
|
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
|
||||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
return betainc(
|
||||||
|
b,
|
||||||
|
a,
|
||||||
|
1 - lax.clamp(
|
||||||
|
_lax_const(x, 0),
|
||||||
|
lax.div(lax.sub(x, loc), scale),
|
||||||
|
_lax_const(x, 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(osp_stats.beta.logsf, update_doc=False)
|
||||||
|
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
|
||||||
|
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
return lax.log(sf(x, a, b, loc, scale))
|
||||||
|
@ -53,9 +53,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
|||||||
|
|
||||||
@_wraps(osp_stats.cauchy.sf, update_doc=False)
|
@_wraps(osp_stats.cauchy.sf, update_doc=False)
|
||||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
x, = promote_args_inexact("cauchy.sf", x)
|
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
|
||||||
cdf_result = cdf(x, loc, scale)
|
return cdf(-x, -loc, scale)
|
||||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
|
||||||
|
|
||||||
|
@_wraps(osp_stats.cauchy.logsf, update_doc=False)
|
||||||
|
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
|
||||||
|
return logcdf(-x, -loc, scale)
|
||||||
|
|
||||||
|
|
||||||
@_wraps(osp_stats.cauchy.isf, update_doc=False)
|
@_wraps(osp_stats.cauchy.isf, update_doc=False)
|
||||||
|
@ -20,7 +20,7 @@ import jax.numpy as jnp
|
|||||||
from jax._src.lax.lax import _const as _lax_const
|
from jax._src.lax.lax import _const as _lax_const
|
||||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||||
from jax._src.typing import Array, ArrayLike
|
from jax._src.typing import Array, ArrayLike
|
||||||
from jax.scipy.special import gammainc
|
from jax.scipy.special import gammainc, gammaincc
|
||||||
|
|
||||||
|
|
||||||
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
|
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
|
||||||
@ -67,5 +67,21 @@ def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
|
|||||||
|
|
||||||
@_wraps(osp_stats.chi2.sf, update_doc=False)
|
@_wraps(osp_stats.chi2.sf, update_doc=False)
|
||||||
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
cdf_result = cdf(x, df, loc, scale)
|
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
|
||||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
two = _lax_const(scale, 2)
|
||||||
|
return gammaincc(
|
||||||
|
lax.div(df, two),
|
||||||
|
lax.clamp(
|
||||||
|
_lax_const(x, 0),
|
||||||
|
lax.div(
|
||||||
|
lax.sub(x, loc),
|
||||||
|
lax.mul(scale, two),
|
||||||
|
),
|
||||||
|
_lax_const(x, jnp.inf),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(osp_stats.chi2.logsf, update_doc=False)
|
||||||
|
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
return lax.log(sf(x, df, loc, scale))
|
||||||
|
@ -59,3 +59,8 @@ def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
|
|||||||
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
|
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
|
||||||
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
|
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(osp_stats.gamma.logsf, update_doc=False)
|
||||||
|
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
return lax.log(sf(x, a, loc, scale))
|
||||||
|
@ -88,18 +88,7 @@ def pdf(x, a, b, loc=0, scale=1):
|
|||||||
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
|
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
|
||||||
def logsf(x, a, b, loc=0, scale=1):
|
def logsf(x, a, b, loc=0, scale=1):
|
||||||
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
|
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
|
||||||
x, a, b = jnp.broadcast_arrays(x, a, b)
|
return logcdf(-x, -b, -a, -loc, scale)
|
||||||
x = lax.div(lax.sub(x, loc), scale)
|
|
||||||
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
|
|
||||||
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)
|
|
||||||
|
|
||||||
logsf = jnp.select(
|
|
||||||
# third condition: avoid catastrophic cancellation (from scipy)
|
|
||||||
[x >= b, x <= a, logsf > -0.1, x > a],
|
|
||||||
[-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf]
|
|
||||||
)
|
|
||||||
logsf = jnp.where(a >= b, jnp.nan, logsf)
|
|
||||||
return logsf
|
|
||||||
|
|
||||||
|
|
||||||
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
|
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
|
||||||
|
@ -16,9 +16,10 @@
|
|||||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||||
|
|
||||||
from jax._src.scipy.stats.beta import (
|
from jax._src.scipy.stats.beta import (
|
||||||
logpdf as logpdf,
|
|
||||||
pdf as pdf,
|
|
||||||
cdf as cdf,
|
cdf as cdf,
|
||||||
logcdf as logcdf,
|
logcdf as logcdf,
|
||||||
|
logpdf as logpdf,
|
||||||
|
logsf as logsf,
|
||||||
|
pdf as pdf,
|
||||||
sf as sf,
|
sf as sf,
|
||||||
)
|
)
|
||||||
|
@ -16,11 +16,12 @@
|
|||||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||||
|
|
||||||
from jax._src.scipy.stats.cauchy import (
|
from jax._src.scipy.stats.cauchy import (
|
||||||
logpdf as logpdf,
|
|
||||||
pdf as pdf,
|
|
||||||
cdf as cdf,
|
cdf as cdf,
|
||||||
logcdf as logcdf,
|
|
||||||
sf as sf,
|
|
||||||
isf as isf,
|
isf as isf,
|
||||||
|
logcdf as logcdf,
|
||||||
|
logpdf as logpdf,
|
||||||
|
logsf as logsf,
|
||||||
|
pdf as pdf,
|
||||||
ppf as ppf,
|
ppf as ppf,
|
||||||
|
sf as sf,
|
||||||
)
|
)
|
||||||
|
@ -16,9 +16,10 @@
|
|||||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||||
|
|
||||||
from jax._src.scipy.stats.chi2 import (
|
from jax._src.scipy.stats.chi2 import (
|
||||||
logpdf as logpdf,
|
|
||||||
pdf as pdf,
|
|
||||||
cdf as cdf,
|
cdf as cdf,
|
||||||
logcdf as logcdf,
|
logcdf as logcdf,
|
||||||
|
logpdf as logpdf,
|
||||||
|
logsf as logsf,
|
||||||
|
pdf as pdf,
|
||||||
sf as sf,
|
sf as sf,
|
||||||
)
|
)
|
||||||
|
@ -16,9 +16,10 @@
|
|||||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||||
|
|
||||||
from jax._src.scipy.stats.gamma import (
|
from jax._src.scipy.stats.gamma import (
|
||||||
logpdf as logpdf,
|
|
||||||
pdf as pdf,
|
|
||||||
cdf as cdf,
|
cdf as cdf,
|
||||||
logcdf as logcdf,
|
logcdf as logcdf,
|
||||||
|
logpdf as logpdf,
|
||||||
|
logsf as logsf,
|
||||||
|
pdf as pdf,
|
||||||
sf as sf,
|
sf as sf,
|
||||||
)
|
)
|
||||||
|
@ -257,6 +257,22 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
self._CompileAndCheck(lax_fun, args_maker,
|
self._CompileAndCheck(lax_fun, args_maker,
|
||||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(5)
|
||||||
|
def testBetaLogSf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.beta.logsf
|
||||||
|
lax_fun = lsp_stats.beta.logsf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, a, b, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=1e-3)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker,
|
||||||
|
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||||
|
|
||||||
def testBetaLogPdfZero(self):
|
def testBetaLogPdfZero(self):
|
||||||
# Regression test for https://github.com/google/jax/issues/7645
|
# Regression test for https://github.com/google/jax/issues/7645
|
||||||
a = b = 1.
|
a = b = 1.
|
||||||
@ -279,7 +295,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
tol=1e-4)
|
tol=1e-4)
|
||||||
self._CompileAndCheck(lax_fun, args_maker)
|
self._CompileAndCheck(lax_fun, args_maker, tol={np.float64: 1E-14})
|
||||||
|
|
||||||
@genNamedParametersNArgs(3)
|
@genNamedParametersNArgs(3)
|
||||||
def testCauchyLogCdf(self, shapes, dtypes):
|
def testCauchyLogCdf(self, shapes, dtypes):
|
||||||
@ -299,6 +315,42 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||||
atol={np.float64: 1e-14})
|
atol={np.float64: 1e-14})
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testCauchyCdf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
scipy_fun = osp_stats.cauchy.cdf
|
||||||
|
lax_fun = lsp_stats.cauchy.cdf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
# clipping to ensure that scale is not too low
|
||||||
|
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||||
|
return [x, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=1e-4)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||||
|
atol={np.float64: 1e-14})
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testCauchyLogSf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
scipy_fun = osp_stats.cauchy.logsf
|
||||||
|
lax_fun = lsp_stats.cauchy.logsf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
# clipping to ensure that scale is not too low
|
||||||
|
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||||
|
return [x, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=1e-4)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||||
|
atol={np.float64: 1e-14})
|
||||||
|
|
||||||
@genNamedParametersNArgs(3)
|
@genNamedParametersNArgs(3)
|
||||||
def testCauchySf(self, shapes, dtypes):
|
def testCauchySf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
@ -314,7 +366,8 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
tol=1e-4)
|
tol=1e-4)
|
||||||
self._CompileAndCheck(lax_fun, args_maker)
|
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
|
||||||
|
atol={np.float64: 1e-14})
|
||||||
|
|
||||||
@genNamedParametersNArgs(3)
|
@genNamedParametersNArgs(3)
|
||||||
def testCauchyIsf(self, shapes, dtypes):
|
def testCauchyIsf(self, shapes, dtypes):
|
||||||
@ -450,6 +503,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
|
|
||||||
@genNamedParametersNArgs(4)
|
@genNamedParametersNArgs(4)
|
||||||
def testGammaLogSf(self, shapes, dtypes):
|
def testGammaLogSf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.gamma.logsf
|
||||||
|
lax_fun = lsp_stats.gamma.logsf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, a, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=5e-4)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(4)
|
||||||
|
def testGammaSf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_positive(self.rng())
|
rng = jtu.rand_positive(self.rng())
|
||||||
scipy_fun = osp_stats.gamma.sf
|
scipy_fun = osp_stats.gamma.sf
|
||||||
lax_fun = lsp_stats.gamma.sf
|
lax_fun = lsp_stats.gamma.sf
|
||||||
@ -960,6 +1028,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
tol=5e-4)
|
tol=5e-4)
|
||||||
self._CompileAndCheck(lax_fun, args_maker)
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(4)
|
||||||
|
def testChi2Cdf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.chi2.cdf
|
||||||
|
lax_fun = lsp_stats.chi2.cdf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, df, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=5e-4)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
@genNamedParametersNArgs(4)
|
@genNamedParametersNArgs(4)
|
||||||
def testChi2Sf(self, shapes, dtypes):
|
def testChi2Sf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_positive(self.rng())
|
rng = jtu.rand_positive(self.rng())
|
||||||
@ -975,6 +1058,21 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
tol=5e-4)
|
tol=5e-4)
|
||||||
self._CompileAndCheck(lax_fun, args_maker)
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(4)
|
||||||
|
def testChi2LogSf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.chi2.logsf
|
||||||
|
lax_fun = lsp_stats.chi2.logsf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, df, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=5e-4)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
@genNamedParametersNArgs(5)
|
@genNamedParametersNArgs(5)
|
||||||
def testBetaBinomLogPmf(self, shapes, dtypes):
|
def testBetaBinomLogPmf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_positive(self.rng())
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user