From feb9ab33af9a2506429b94bef5dd6ddefbd06d01 Mon Sep 17 00:00:00 2001 From: Misha <48orusef@gmail.com> Date: Mon, 6 Mar 2023 00:11:40 +0100 Subject: [PATCH] Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic. --- docs/jax.scipy.rst | 17 ++- jax/_src/scipy/stats/beta.py | 30 +++- jax/_src/scipy/stats/cauchy.py | 39 +++++ jax/_src/scipy/stats/chi2.py | 50 +++++-- jax/_src/scipy/stats/gamma.py | 27 +++- jax/_src/scipy/stats/logistic.py | 35 +++-- jax/_src/scipy/stats/norm.py | 11 ++ jax/_src/scipy/stats/t.py | 1 + jax/scipy/stats/beta.py | 3 + jax/scipy/stats/cauchy.py | 5 + jax/scipy/stats/chi2.py | 3 + jax/scipy/stats/gamma.py | 3 + jax/scipy/stats/norm.py | 2 + tests/scipy_stats_test.py | 248 ++++++++++++++++++++++++++++++- 14 files changed, 440 insertions(+), 34 deletions(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 5d798f76c..cad15d09f 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -172,6 +172,9 @@ jax.scipy.stats.beta logpdf pdf + cdf + logcdf + sf jax.scipy.stats.betabinom ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -192,6 +195,11 @@ jax.scipy.stats.cauchy logpdf pdf + cdf + logcdf + sf + isf + ppf jax.scipy.stats.chi2 ~~~~~~~~~~~~~~~~~~~~ @@ -202,7 +210,9 @@ jax.scipy.stats.chi2 logpdf pdf - + cdf + logcdf + sf jax.scipy.stats.dirichlet @@ -232,6 +242,9 @@ jax.scipy.stats.gamma logpdf pdf + cdf + logcdf + sf jax.scipy.stats.gennorm ~~~~~~~~~~~~~~~~~~~~~~~ @@ -296,6 +309,8 @@ jax.scipy.stats.norm logpdf pdf ppf + sf + isf jax.scipy.stats.pareto ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index d4788ec1a..9f0669fff 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or from jax._src.typing import Array, ArrayLike -from jax.scipy.special import betaln, xlogy, xlog1py +from jax.scipy.special import betaln, betainc, xlogy, xlog1py @_wraps(osp_stats.beta.logpdf, update_doc=False) @@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.beta.cdf, update_doc=False) +def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, b, loc, scale = _promote_args_inexact("beta.cdf", x, a, b, loc, scale) + return betainc( + a, + b, + lax.clamp( + _lax_const(x, 0), + lax.div(lax.sub(x, loc), scale), + _lax_const(x, 1), + ) + ) + + +@_wraps(osp_stats.beta.logcdf, update_doc=False) +def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.beta.sf, update_doc=False) +def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, a, b, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 5e5044f74..6ae947ae7 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -20,6 +20,7 @@ from jax import lax from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact +from jax._src.numpy.lax_numpy import arctan from jax._src.typing import Array, ArrayLike @@ -31,6 +32,44 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: normalize_term = lax.log(lax.mul(pi, scale)) return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) + @_wraps(osp_stats.cauchy.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) + + +@_wraps(osp_stats.cauchy.cdf, update_doc=False) +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = _promote_args_inexact("cauchy.cdf", x, loc, scale) + pi = _lax_const(x, np.pi) + scaled_x = lax.div(lax.sub(x, loc), scale) + return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x))) + + +@_wraps(osp_stats.cauchy.logcdf, update_doc=False) +def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, loc, scale)) + + +@_wraps(osp_stats.cauchy.sf, update_doc=False) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, = _promote_args_inexact("cauchy.sf", x) + return lax.sub(_lax_const(x, 1), cdf(x, loc, scale)) + + +@_wraps(osp_stats.cauchy.isf, update_doc=False) +def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + q, loc, scale = _promote_args_inexact("cauchy.isf", q, loc, scale) + pi = _lax_const(q, np.pi) + half_pi = _lax_const(q, np.pi / 2) + unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q))) + return lax.add(lax.mul(unscaled, scale), loc) + + +@_wraps(osp_stats.cauchy.ppf, update_doc=False) +def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + q, loc, scale = _promote_args_inexact("cauchy.ppf", q, loc, scale) + pi = _lax_const(q, np.pi) + half_pi = _lax_const(q, np.pi / 2) + unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi)) + return lax.add(lax.mul(unscaled, scale), loc) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index 2840bc183..29195119d 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -20,23 +20,53 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf from jax._src.typing import Array, ArrayLike +from jax.scipy.special import gammainc @_wraps(osp_stats.chi2.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale) - one = _lax_const(x, 1) - two = _lax_const(x, 2) - y = lax.div(lax.sub(x, loc), scale) - df_on_two = lax.div(df, two) + x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale) + one = _lax_const(x, 1) + two = _lax_const(x, 2) + y = lax.div(lax.sub(x, loc), scale) + df_on_two = lax.div(df, two) - kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) + kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) - nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) + nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) + + log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) + return where(lax.lt(x, loc), -inf, log_probs) - log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) - return where(lax.lt(x, loc), -inf, log_probs) @_wraps(osp_stats.chi2.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.exp(logpdf(x, df, loc, scale)) + return lax.exp(logpdf(x, df, loc, scale)) + + +@_wraps(osp_stats.chi2.cdf, update_doc=False) +def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, df, loc, scale = _promote_args_inexact("chi2.cdf", x, df, loc, scale) + two = _lax_const(scale, 2) + return gammainc( + lax.div(df, two), + lax.clamp( + _lax_const(x, 0), + lax.div( + lax.sub(x, loc), + lax.mul(scale, two), + ), + inf, + ), + ) + + +@_wraps(osp_stats.chi2.logcdf, update_doc=False) +def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, df, loc, scale)) + + +@_wraps(osp_stats.chi2.sf, update_doc=False) +def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, df, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 94ca30c49..03ec8baa5 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -19,7 +19,7 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammaln, xlogy +from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc @_wraps(osp_stats.gamma.logpdf, update_doc=False) @@ -32,6 +32,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_probs = lax.sub(log_linear_term, shape_terms) return where(lax.lt(x, loc), -inf, log_probs) + @_wraps(osp_stats.gamma.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, loc, scale)) + + +@_wraps(osp_stats.gamma.cdf, update_doc=False) +def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, loc, scale = _promote_args_inexact("gamma.cdf", x, a, loc, scale) + return gammainc( + a, + lax.clamp( + _lax_const(x, 0), + lax.div(lax.sub(x, loc), scale), + inf, + ) + ) + + +@_wraps(osp_stats.gamma.logcdf, update_doc=False) +def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, a, loc, scale)) + + +@_wraps(osp_stats.gamma.sf, update_doc=False) +def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, loc, scale = _promote_args_inexact("gamma.sf", x, a, loc, scale) + return gammaincc(a, lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index e021d14d2..d8dacf49a 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -24,29 +24,38 @@ from jax._src.typing import Array, ArrayLike @_wraps(osp_stats.logistic.logpdf, update_doc=False) -def logpdf(x: ArrayLike) -> Array: - x, = _promote_args_inexact("logistic.logpdf", x) +def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = _promote_args_inexact("logistic.logpdf", x, loc, scale) + x = lax.div(lax.sub(x, loc), scale) two = _lax_const(x, 2) half_x = lax.div(x, two) - return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))) + return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale)) @_wraps(osp_stats.logistic.pdf, update_doc=False) -def pdf(x: ArrayLike) -> Array: - return lax.exp(logpdf(x)) +def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.exp(logpdf(x, loc, scale)) + @_wraps(osp_stats.logistic.ppf, update_doc=False) -def ppf(x): - return logit(x) +def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = _promote_args_inexact("logistic.ppf", x, loc, scale) + return lax.add(lax.mul(logit(x), scale), loc) + @_wraps(osp_stats.logistic.sf, update_doc=False) -def sf(x): - return expit(lax.neg(x)) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = _promote_args_inexact("logistic.sf", x, loc, scale) + return expit(lax.neg(lax.div(lax.sub(x, loc), scale))) + @_wraps(osp_stats.logistic.isf, update_doc=False) -def isf(x): - return -logit(x) +def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = _promote_args_inexact("logistic.isf", x, loc, scale) + return lax.add(lax.mul(lax.neg(logit(x)), scale), loc) + @_wraps(osp_stats.logistic.cdf, update_doc=False) -def cdf(x): - return expit(x) +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = _promote_args_inexact("logistic.cdf", x, loc, scale) + return expit(lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index d74c47132..1960aca58 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -25,6 +25,7 @@ from jax._src.numpy.lax_numpy import _promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy import special + @_wraps(osp_stats.norm.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale) @@ -55,3 +56,13 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @_wraps(osp_stats.norm.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.asarray(special.ndtri(q) * scale + loc, float) + + +@_wraps(osp_stats.norm.sf, update_doc=False) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.sub(_lax_const(x, 1), cdf(x, loc, scale)) + + +@_wraps(osp_stats.norm.isf, update_doc=False) +def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return ppf(lax.sub(_lax_const(q, 1), q), loc, scale) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index aff476cfd..22d2d053a 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -37,6 +37,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 quadratic = lax.div(lax.mul(scaled_x, scaled_x), df) return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic)))) + @_wraps(osp_stats.t.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, df, loc, scale)) diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 30653f0c3..963181fa0 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.beta import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index af6c1ba49..b3b0d994c 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -18,4 +18,9 @@ from jax._src.scipy.stats.cauchy import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, + isf as isf, + ppf as ppf, ) diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index 349c0c7cc..9cb28c8a6 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.chi2 import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 3c518f43f..268fc4fa0 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.gamma import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index fd9506d7c..c6b85f25d 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -21,4 +21,6 @@ from jax._src.scipy.stats.norm import ( logpdf as logpdf, pdf as pdf, ppf as ppf, + sf as sf, + isf as isf, ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 0b5984d58..071495aff 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -226,6 +226,38 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 2e-3, np.float64: 1e-4}) + @genNamedParametersNArgs(5) + def testBetaLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.beta.logcdf + lax_fun = lsp_stats.beta.logcdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) + + @genNamedParametersNArgs(5) + def testBetaSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.beta.sf + lax_fun = lsp_stats.beta.sf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) + def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. @@ -250,6 +282,80 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testCauchyLogCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.logcdf + lax_fun = lsp_stats.cauchy.logcdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testCauchySf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.sf + lax_fun = lsp_stats.cauchy.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testCauchyIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.isf + lax_fun = lsp_stats.cauchy.isf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that q is in desired range + # since lax.tan and numpy.tan work different near divergence points + q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testCauchyPpf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.ppf + lax_fun = lsp_stats.cauchy.ppf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that q is in desired + # since lax.tan and numpy.tan work different near divergence points + q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-4) + self._CompileAndCheck(lax_fun, args_maker) + @jtu.sample_product( shapes=[ [x_shape, alpha_shape] @@ -321,6 +427,37 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) + def testGammaLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.gamma.logcdf + lax_fun = lsp_stats.gamma.logcdf + + def args_maker(): + x, a, loc, scale = map(rng, shapes, dtypes) + x = np.clip(x, 0, None).astype(x.dtype) + return [x, a, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testGammaLogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.gamma.sf + lax_fun = lsp_stats.gamma.sf + + def args_maker(): + x, a, loc, scale = map(rng, shapes, dtypes) + return [x, a, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + def testGammaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7256 self.assertAllClose( @@ -411,28 +548,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol={np.float32: 1e-5, np.float64: 1e-6}) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=3e-5) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticLogpdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.logpdf lax_fun = lsp_stats.logistic.logpdf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) @@ -445,32 +588,54 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), check_dtypes=False) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.ppf lax_fun = lsp_stats.logistic.ppf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.sf lax_fun = lsp_stats.logistic.sf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=2e-5) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testLogisticIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.logistic.isf + lax_fun = lsp_stats.logistic.isf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) def testNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -525,6 +690,24 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testNormSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.norm.sf + lax_fun = lsp_stats.norm.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-6) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -543,6 +726,25 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + + @genNamedParametersNArgs(3) + def testNormIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.norm.isf + lax_fun = lsp_stats.norm.isf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # ensure probability is between 0 and 1: + q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @genNamedParametersNArgs(5) def testTruncnormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -716,6 +918,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) + def testChi2LogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.logcdf + lax_fun = lsp_stats.chi2.logcdf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testChi2Sf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.sf + lax_fun = lsp_stats.chi2.sf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng())