diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index cad15d09f..5d798f76c 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -172,9 +172,6 @@ jax.scipy.stats.beta logpdf pdf - cdf - logcdf - sf jax.scipy.stats.betabinom ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -195,11 +192,6 @@ jax.scipy.stats.cauchy logpdf pdf - cdf - logcdf - sf - isf - ppf jax.scipy.stats.chi2 ~~~~~~~~~~~~~~~~~~~~ @@ -210,9 +202,7 @@ jax.scipy.stats.chi2 logpdf pdf - cdf - logcdf - sf + jax.scipy.stats.dirichlet @@ -242,9 +232,6 @@ jax.scipy.stats.gamma logpdf pdf - cdf - logcdf - sf jax.scipy.stats.gennorm ~~~~~~~~~~~~~~~~~~~~~~~ @@ -309,8 +296,6 @@ jax.scipy.stats.norm logpdf pdf ppf - sf - isf jax.scipy.stats.pareto ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 9f0669fff..d4788ec1a 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or from jax._src.typing import Array, ArrayLike -from jax.scipy.special import betaln, betainc, xlogy, xlog1py +from jax.scipy.special import betaln, xlogy, xlog1py @_wraps(osp_stats.beta.logpdf, update_doc=False) @@ -40,31 +40,3 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, b, loc, scale)) - - -@_wraps(osp_stats.beta.cdf, update_doc=False) -def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, - loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, a, b, loc, scale = _promote_args_inexact("beta.cdf", x, a, b, loc, scale) - return betainc( - a, - b, - lax.clamp( - _lax_const(x, 0), - lax.div(lax.sub(x, loc), scale), - _lax_const(x, 1), - ) - ) - - -@_wraps(osp_stats.beta.logcdf, update_doc=False) -def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, - loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.log(cdf(x, a, b, loc, scale)) - - -@_wraps(osp_stats.beta.sf, update_doc=False) -def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, - loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - cdf_result = cdf(x, a, b, loc, scale) - return lax.sub(_lax_const(cdf_result, 1), cdf_result) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 6ae947ae7..5e5044f74 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -20,7 +20,6 @@ from jax import lax from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact -from jax._src.numpy.lax_numpy import arctan from jax._src.typing import Array, ArrayLike @@ -32,44 +31,6 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: normalize_term = lax.log(lax.mul(pi, scale)) return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) - @_wraps(osp_stats.cauchy.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) - - -@_wraps(osp_stats.cauchy.cdf, update_doc=False) -def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, loc, scale = _promote_args_inexact("cauchy.cdf", x, loc, scale) - pi = _lax_const(x, np.pi) - scaled_x = lax.div(lax.sub(x, loc), scale) - return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x))) - - -@_wraps(osp_stats.cauchy.logcdf, update_doc=False) -def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.log(cdf(x, loc, scale)) - - -@_wraps(osp_stats.cauchy.sf, update_doc=False) -def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, = _promote_args_inexact("cauchy.sf", x) - return lax.sub(_lax_const(x, 1), cdf(x, loc, scale)) - - -@_wraps(osp_stats.cauchy.isf, update_doc=False) -def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - q, loc, scale = _promote_args_inexact("cauchy.isf", q, loc, scale) - pi = _lax_const(q, np.pi) - half_pi = _lax_const(q, np.pi / 2) - unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q))) - return lax.add(lax.mul(unscaled, scale), loc) - - -@_wraps(osp_stats.cauchy.ppf, update_doc=False) -def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - q, loc, scale = _promote_args_inexact("cauchy.ppf", q, loc, scale) - pi = _lax_const(q, np.pi) - half_pi = _lax_const(q, np.pi / 2) - unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi)) - return lax.add(lax.mul(unscaled, scale), loc) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index 29195119d..2840bc183 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -20,53 +20,23 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammainc @_wraps(osp_stats.chi2.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale) - one = _lax_const(x, 1) - two = _lax_const(x, 2) - y = lax.div(lax.sub(x, loc), scale) - df_on_two = lax.div(df, two) + x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale) + one = _lax_const(x, 1) + two = _lax_const(x, 2) + y = lax.div(lax.sub(x, loc), scale) + df_on_two = lax.div(df, two) - kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) + kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) - nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) - - log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) - return where(lax.lt(x, loc), -inf, log_probs) + nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) + log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) + return where(lax.lt(x, loc), -inf, log_probs) @_wraps(osp_stats.chi2.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.exp(logpdf(x, df, loc, scale)) - - -@_wraps(osp_stats.chi2.cdf, update_doc=False) -def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, df, loc, scale = _promote_args_inexact("chi2.cdf", x, df, loc, scale) - two = _lax_const(scale, 2) - return gammainc( - lax.div(df, two), - lax.clamp( - _lax_const(x, 0), - lax.div( - lax.sub(x, loc), - lax.mul(scale, two), - ), - inf, - ), - ) - - -@_wraps(osp_stats.chi2.logcdf, update_doc=False) -def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.log(cdf(x, df, loc, scale)) - - -@_wraps(osp_stats.chi2.sf, update_doc=False) -def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - cdf_result = cdf(x, df, loc, scale) - return lax.sub(_lax_const(cdf_result, 1), cdf_result) + return lax.exp(logpdf(x, df, loc, scale)) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 03ec8baa5..94ca30c49 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc +from jax.scipy.special import gammaln, xlogy @_wraps(osp_stats.gamma.logpdf, update_doc=False) @@ -32,31 +32,6 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_probs = lax.sub(log_linear_term, shape_terms) return where(lax.lt(x, loc), -inf, log_probs) - @_wraps(osp_stats.gamma.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, loc, scale)) - - -@_wraps(osp_stats.gamma.cdf, update_doc=False) -def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, a, loc, scale = _promote_args_inexact("gamma.cdf", x, a, loc, scale) - return gammainc( - a, - lax.clamp( - _lax_const(x, 0), - lax.div(lax.sub(x, loc), scale), - inf, - ) - ) - - -@_wraps(osp_stats.gamma.logcdf, update_doc=False) -def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.log(cdf(x, a, loc, scale)) - - -@_wraps(osp_stats.gamma.sf, update_doc=False) -def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, a, loc, scale = _promote_args_inexact("gamma.sf", x, a, loc, scale) - return gammaincc(a, lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index d8dacf49a..e021d14d2 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -24,38 +24,29 @@ from jax._src.typing import Array, ArrayLike @_wraps(osp_stats.logistic.logpdf, update_doc=False) -def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, loc, scale = _promote_args_inexact("logistic.logpdf", x, loc, scale) - x = lax.div(lax.sub(x, loc), scale) +def logpdf(x: ArrayLike) -> Array: + x, = _promote_args_inexact("logistic.logpdf", x) two = _lax_const(x, 2) half_x = lax.div(x, two) - return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale)) + return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))) @_wraps(osp_stats.logistic.pdf, update_doc=False) -def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.exp(logpdf(x, loc, scale)) - +def pdf(x: ArrayLike) -> Array: + return lax.exp(logpdf(x)) @_wraps(osp_stats.logistic.ppf, update_doc=False) -def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, loc, scale = _promote_args_inexact("logistic.ppf", x, loc, scale) - return lax.add(lax.mul(logit(x), scale), loc) - +def ppf(x): + return logit(x) @_wraps(osp_stats.logistic.sf, update_doc=False) -def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, loc, scale = _promote_args_inexact("logistic.sf", x, loc, scale) - return expit(lax.neg(lax.div(lax.sub(x, loc), scale))) - +def sf(x): + return expit(lax.neg(x)) @_wraps(osp_stats.logistic.isf, update_doc=False) -def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, loc, scale = _promote_args_inexact("logistic.isf", x, loc, scale) - return lax.add(lax.mul(lax.neg(logit(x)), scale), loc) - +def isf(x): + return -logit(x) @_wraps(osp_stats.logistic.cdf, update_doc=False) -def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, loc, scale = _promote_args_inexact("logistic.cdf", x, loc, scale) - return expit(lax.div(lax.sub(x, loc), scale)) +def cdf(x): + return expit(x) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 1960aca58..d74c47132 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -25,7 +25,6 @@ from jax._src.numpy.lax_numpy import _promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy import special - @_wraps(osp_stats.norm.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale) @@ -56,13 +55,3 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @_wraps(osp_stats.norm.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.asarray(special.ndtri(q) * scale + loc, float) - - -@_wraps(osp_stats.norm.sf, update_doc=False) -def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.sub(_lax_const(x, 1), cdf(x, loc, scale)) - - -@_wraps(osp_stats.norm.isf, update_doc=False) -def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return ppf(lax.sub(_lax_const(q, 1), q), loc, scale) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index 22d2d053a..aff476cfd 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -37,7 +37,6 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 quadratic = lax.div(lax.mul(scaled_x, scaled_x), df) return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic)))) - @_wraps(osp_stats.t.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, df, loc, scale)) diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 963181fa0..30653f0c3 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -18,7 +18,4 @@ from jax._src.scipy.stats.beta import ( logpdf as logpdf, pdf as pdf, - cdf as cdf, - logcdf as logcdf, - sf as sf, ) diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index b3b0d994c..af6c1ba49 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -18,9 +18,4 @@ from jax._src.scipy.stats.cauchy import ( logpdf as logpdf, pdf as pdf, - cdf as cdf, - logcdf as logcdf, - sf as sf, - isf as isf, - ppf as ppf, ) diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index 9cb28c8a6..349c0c7cc 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -18,7 +18,4 @@ from jax._src.scipy.stats.chi2 import ( logpdf as logpdf, pdf as pdf, - cdf as cdf, - logcdf as logcdf, - sf as sf, ) diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 268fc4fa0..3c518f43f 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -18,7 +18,4 @@ from jax._src.scipy.stats.gamma import ( logpdf as logpdf, pdf as pdf, - cdf as cdf, - logcdf as logcdf, - sf as sf, ) diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index c6b85f25d..fd9506d7c 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -21,6 +21,4 @@ from jax._src.scipy.stats.norm import ( logpdf as logpdf, pdf as pdf, ppf as ppf, - sf as sf, - isf as isf, ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 071495aff..0b5984d58 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -226,38 +226,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 2e-3, np.float64: 1e-4}) - @genNamedParametersNArgs(5) - def testBetaLogCdf(self, shapes, dtypes): - rng = jtu.rand_positive(self.rng()) - scipy_fun = osp_stats.beta.logcdf - lax_fun = lsp_stats.beta.logcdf - - def args_maker(): - x, a, b, loc, scale = map(rng, shapes, dtypes) - return [x, a, b, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, - rtol={np.float32: 2e-3, np.float64: 1e-4}) - - @genNamedParametersNArgs(5) - def testBetaSf(self, shapes, dtypes): - rng = jtu.rand_positive(self.rng()) - scipy_fun = osp_stats.beta.sf - lax_fun = lsp_stats.beta.sf - - def args_maker(): - x, a, b, loc, scale = map(rng, shapes, dtypes) - return [x, a, b, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, - rtol={np.float32: 2e-3, np.float64: 1e-4}) - def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. @@ -282,80 +250,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(3) - def testCauchyLogCdf(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - scipy_fun = osp_stats.cauchy.logcdf - lax_fun = lsp_stats.cauchy.logcdf - - def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) - - @genNamedParametersNArgs(3) - def testCauchySf(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - scipy_fun = osp_stats.cauchy.sf - lax_fun = lsp_stats.cauchy.sf - - def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) - - @genNamedParametersNArgs(3) - def testCauchyIsf(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - scipy_fun = osp_stats.cauchy.isf - lax_fun = lsp_stats.cauchy.isf - - def args_maker(): - q, loc, scale = map(rng, shapes, dtypes) - # clipping to ensure that q is in desired range - # since lax.tan and numpy.tan work different near divergence points - q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) - # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) - return [q, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=2e-4) - self._CompileAndCheck(lax_fun, args_maker) - - @genNamedParametersNArgs(3) - def testCauchyPpf(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - scipy_fun = osp_stats.cauchy.ppf - lax_fun = lsp_stats.cauchy.ppf - - def args_maker(): - q, loc, scale = map(rng, shapes, dtypes) - # clipping to ensure that q is in desired - # since lax.tan and numpy.tan work different near divergence points - q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) - # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) - return [q, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=2e-4) - self._CompileAndCheck(lax_fun, args_maker) - @jtu.sample_product( shapes=[ [x_shape, alpha_shape] @@ -427,37 +321,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(4) - def testGammaLogCdf(self, shapes, dtypes): - rng = jtu.rand_positive(self.rng()) - scipy_fun = osp_stats.gamma.logcdf - lax_fun = lsp_stats.gamma.logcdf - - def args_maker(): - x, a, loc, scale = map(rng, shapes, dtypes) - x = np.clip(x, 0, None).astype(x.dtype) - return [x, a, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) - - @genNamedParametersNArgs(4) - def testGammaLogSf(self, shapes, dtypes): - rng = jtu.rand_positive(self.rng()) - scipy_fun = osp_stats.gamma.sf - lax_fun = lsp_stats.gamma.sf - - def args_maker(): - x, a, loc, scale = map(rng, shapes, dtypes) - return [x, a, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) - def testGammaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7256 self.assertAllClose( @@ -548,34 +411,28 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol={np.float32: 1e-5, np.float64: 1e-6}) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(3) + @genNamedParametersNArgs(1) def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # ensure that scale is not too low - scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] + return list(map(rng, shapes, dtypes)) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=3e-5) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(3) + @genNamedParametersNArgs(1) def testLogisticLogpdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.logpdf lax_fun = lsp_stats.logistic.logpdf def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # ensure that scale is not too low - scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] + return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) @@ -588,54 +445,32 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), check_dtypes=False) - @genNamedParametersNArgs(3) + @genNamedParametersNArgs(1) def testLogisticPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.ppf lax_fun = lsp_stats.logistic.ppf def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # ensure that scale is not too low - scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] + return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(3) + @genNamedParametersNArgs(1) def testLogisticSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.sf lax_fun = lsp_stats.logistic.sf def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # ensure that scale is not too low - scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] + return list(map(rng, shapes, dtypes)) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=2e-5) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(3) - def testLogisticIsf(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - scipy_fun = osp_stats.logistic.isf - lax_fun = lsp_stats.logistic.isf - - def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # ensure that scale is not too low - scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] - - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(3) def testNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -690,24 +525,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(3) - def testNormSf(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - scipy_fun = osp_stats.norm.sf - lax_fun = lsp_stats.norm.sf - - def args_maker(): - x, loc, scale = map(rng, shapes, dtypes) - # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) - return [x, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-6) - self._CompileAndCheck(lax_fun, args_maker) - - @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -726,25 +543,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) - - @genNamedParametersNArgs(3) - def testNormIsf(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - scipy_fun = osp_stats.norm.isf - lax_fun = lsp_stats.norm.isf - - def args_maker(): - q, loc, scale = map(rng, shapes, dtypes) - # ensure probability is between 0 and 1: - q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype) - # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) - return [q, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) - @genNamedParametersNArgs(5) def testTruncnormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -918,36 +716,6 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(4) - def testChi2LogCdf(self, shapes, dtypes): - rng = jtu.rand_positive(self.rng()) - scipy_fun = osp_stats.chi2.logcdf - lax_fun = lsp_stats.chi2.logcdf - - def args_maker(): - x, df, loc, scale = map(rng, shapes, dtypes) - return [x, df, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) - - @genNamedParametersNArgs(4) - def testChi2Sf(self, shapes, dtypes): - rng = jtu.rand_positive(self.rng()) - scipy_fun = osp_stats.chi2.sf - lax_fun = lsp_stats.chi2.sf - - def args_maker(): - x, df, loc, scale = map(rng, shapes, dtypes) - return [x, df, loc, scale] - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng())