From 83b3f5b759c3eac93ae8ff1104d862be94ff6d5b Mon Sep 17 00:00:00 2001 From: Misha <48orusef@gmail.com> Date: Sun, 12 Mar 2023 06:53:09 +0100 Subject: [PATCH] Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. --- docs/jax.scipy.rst | 17 +- jax/_src/scipy/stats/beta.py | 30 +++- jax/_src/scipy/stats/cauchy.py | 44 ++++- jax/_src/scipy/stats/chi2.py | 49 ++++-- jax/_src/scipy/stats/gamma.py | 26 ++- jax/_src/scipy/stats/logistic.py | 35 ++-- jax/_src/scipy/stats/norm.py | 12 ++ jax/_src/scipy/stats/t.py | 1 + jax/scipy/stats/beta.py | 3 + jax/scipy/stats/cauchy.py | 5 + jax/scipy/stats/chi2.py | 3 + jax/scipy/stats/gamma.py | 3 + jax/scipy/stats/norm.py | 2 + tests/scipy_stats_test.py | 269 ++++++++++++++++++++++++++++--- 14 files changed, 453 insertions(+), 46 deletions(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 5d798f76c..cad15d09f 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -172,6 +172,9 @@ jax.scipy.stats.beta logpdf pdf + cdf + logcdf + sf jax.scipy.stats.betabinom ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -192,6 +195,11 @@ jax.scipy.stats.cauchy logpdf pdf + cdf + logcdf + sf + isf + ppf jax.scipy.stats.chi2 ~~~~~~~~~~~~~~~~~~~~ @@ -202,7 +210,9 @@ jax.scipy.stats.chi2 logpdf pdf - + cdf + logcdf + sf jax.scipy.stats.dirichlet @@ -232,6 +242,9 @@ jax.scipy.stats.gamma logpdf pdf + cdf + logcdf + sf jax.scipy.stats.gennorm ~~~~~~~~~~~~~~~~~~~~~~~ @@ -296,6 +309,8 @@ jax.scipy.stats.norm logpdf pdf ppf + sf + isf jax.scipy.stats.pareto ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 73ba2725d..cf8ae2e3b 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike -from jax.scipy.special import betaln, xlogy, xlog1py +from jax.scipy.special import betaln, betainc, xlogy, xlog1py @_wraps(osp_stats.beta.logpdf, update_doc=False) @@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.beta.cdf, update_doc=False) +def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale) + return betainc( + a, + b, + lax.clamp( + _lax_const(x, 0), + lax.div(lax.sub(x, loc), scale), + _lax_const(x, 1), + ) + ) + + +@_wraps(osp_stats.beta.logcdf, update_doc=False) +def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, a, b, loc, scale)) + + +@_wraps(osp_stats.beta.sf, update_doc=False) +def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, a, b, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 169cc6f85..426b1eec0 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -18,8 +18,8 @@ import scipy.stats as osp_stats from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps -from jax._src.numpy.util import promote_args_inexact +from jax._src.numpy.util import _wraps, promote_args_inexact +from jax.numpy import arctan from jax._src.typing import Array, ArrayLike @@ -31,6 +31,46 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: normalize_term = lax.log(lax.mul(pi, scale)) return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) + @_wraps(osp_stats.cauchy.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) + + + +@_wraps(osp_stats.cauchy.cdf, update_doc=False) +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale) + pi = _lax_const(x, np.pi) + scaled_x = lax.div(lax.sub(x, loc), scale) + return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x))) + + +@_wraps(osp_stats.cauchy.logcdf, update_doc=False) +def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, loc, scale)) + + +@_wraps(osp_stats.cauchy.sf, update_doc=False) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, = promote_args_inexact("cauchy.sf", x) + cdf_result = cdf(x, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) + + +@_wraps(osp_stats.cauchy.isf, update_doc=False) +def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale) + pi = _lax_const(q, np.pi) + half_pi = _lax_const(q, np.pi / 2) + unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q))) + return lax.add(lax.mul(unscaled, scale), loc) + + +@_wraps(osp_stats.cauchy.ppf, update_doc=False) +def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale) + pi = _lax_const(q, np.pi) + half_pi = _lax_const(q, np.pi / 2) + unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi)) + return lax.add(lax.mul(unscaled, scale), loc) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index f82d8de02..912f225be 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -20,23 +20,52 @@ import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike +from jax.scipy.special import gammainc @_wraps(osp_stats.chi2.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale) - one = _lax_const(x, 1) - two = _lax_const(x, 2) - y = lax.div(lax.sub(x, loc), scale) - df_on_two = lax.div(df, two) + x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale) + one = _lax_const(x, 1) + two = _lax_const(x, 2) + y = lax.div(lax.sub(x, loc), scale) + df_on_two = lax.div(df, two) - kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) + kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two)) - nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) + nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two))) - log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) - return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) + log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) + return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) @_wraps(osp_stats.chi2.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - return lax.exp(logpdf(x, df, loc, scale)) + return lax.exp(logpdf(x, df, loc, scale)) + + +@_wraps(osp_stats.chi2.cdf, update_doc=False) +def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale) + two = _lax_const(scale, 2) + return gammainc( + lax.div(df, two), + lax.clamp( + _lax_const(x, 0), + lax.div( + lax.sub(x, loc), + lax.mul(scale, two), + ), + _lax_const(x, jnp.inf), + ), + ) + + +@_wraps(osp_stats.chi2.logcdf, update_doc=False) +def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, df, loc, scale)) + + +@_wraps(osp_stats.chi2.sf, update_doc=False) +def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, df, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index ab9f47596..dcfb9439a 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammaln, xlogy +from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc @_wraps(osp_stats.gamma.logpdf, update_doc=False) @@ -35,3 +35,27 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) @_wraps(osp_stats.gamma.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, loc, scale)) + + +@_wraps(osp_stats.gamma.cdf, update_doc=False) +def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale) + return gammainc( + a, + lax.clamp( + _lax_const(x, 0), + lax.div(lax.sub(x, loc), scale), + _lax_const(x, jnp.inf), + ) + ) + + +@_wraps(osp_stats.gamma.logcdf, update_doc=False) +def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(cdf(x, a, loc, scale)) + + +@_wraps(osp_stats.gamma.sf, update_doc=False) +def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) + return gammaincc(a, lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index effdf21b7..67901e83f 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -23,29 +23,38 @@ from jax._src.typing import Array, ArrayLike @_wraps(osp_stats.logistic.logpdf, update_doc=False) -def logpdf(x: ArrayLike) -> Array: - x, = promote_args_inexact("logistic.logpdf", x) +def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale) + x = lax.div(lax.sub(x, loc), scale) two = _lax_const(x, 2) half_x = lax.div(x, two) - return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))) + return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale)) @_wraps(osp_stats.logistic.pdf, update_doc=False) -def pdf(x: ArrayLike) -> Array: - return lax.exp(logpdf(x)) +def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.exp(logpdf(x, loc, scale)) + @_wraps(osp_stats.logistic.ppf, update_doc=False) -def ppf(x): - return logit(x) +def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale) + return lax.add(lax.mul(logit(x), scale), loc) + @_wraps(osp_stats.logistic.sf, update_doc=False) -def sf(x): - return expit(lax.neg(x)) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale) + return expit(lax.neg(lax.div(lax.sub(x, loc), scale))) + @_wraps(osp_stats.logistic.isf, update_doc=False) -def isf(x): - return -logit(x) +def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale) + return lax.add(lax.mul(lax.neg(logit(x)), scale), loc) + @_wraps(osp_stats.logistic.cdf, update_doc=False) -def cdf(x): - return expit(x) +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale) + return expit(lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 011b566de..4c72bcac5 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -24,6 +24,7 @@ from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy import special + @_wraps(osp_stats.norm.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale) @@ -54,3 +55,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @_wraps(osp_stats.norm.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.asarray(special.ndtri(q) * scale + loc, float) + + +@_wraps(osp_stats.norm.sf, update_doc=False) +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + cdf_result = cdf(x, loc, scale) + return lax.sub(_lax_const(cdf_result, 1), cdf_result) + + +@_wraps(osp_stats.norm.isf, update_doc=False) +def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return ppf(lax.sub(_lax_const(q, 1), q), loc, scale) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index efc615554..5a54f2bf5 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -36,6 +36,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 quadratic = lax.div(lax.mul(scaled_x, scaled_x), df) return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic)))) + @_wraps(osp_stats.t.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, df, loc, scale)) diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 30653f0c3..963181fa0 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.beta import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index af6c1ba49..b3b0d994c 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -18,4 +18,9 @@ from jax._src.scipy.stats.cauchy import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, + isf as isf, + ppf as ppf, ) diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index 349c0c7cc..9cb28c8a6 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.chi2 import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 3c518f43f..268fc4fa0 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -18,4 +18,7 @@ from jax._src.scipy.stats.gamma import ( logpdf as logpdf, pdf as pdf, + cdf as cdf, + logcdf as logcdf, + sf as sf, ) diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index fd9506d7c..c6b85f25d 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -21,4 +21,6 @@ from jax._src.scipy.stats.norm import ( logpdf as logpdf, pdf as pdf, ppf as ppf, + sf as sf, + isf as isf, ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 0b5984d58..fb7a5af9a 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -190,7 +190,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(3) def testGeomLogPmf(self, shapes, dtypes): @@ -226,6 +226,38 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 2e-3, np.float64: 1e-4}) + @genNamedParametersNArgs(5) + def testBetaLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.beta.logcdf + lax_fun = lsp_stats.beta.logcdf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) + + @genNamedParametersNArgs(5) + def testBetaSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.beta.sf + lax_fun = lsp_stats.beta.sf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) + def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. @@ -250,6 +282,80 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testCauchyLogCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.logcdf + lax_fun = lsp_stats.cauchy.logcdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testCauchySf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.sf + lax_fun = lsp_stats.cauchy.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testCauchyIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.isf + lax_fun = lsp_stats.cauchy.isf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that q is in desired range + # since lax.tan and numpy.tan work different near divergence points + q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + + @genNamedParametersNArgs(3) + def testCauchyPpf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.ppf + lax_fun = lsp_stats.cauchy.ppf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that q is in desired + # since lax.tan and numpy.tan work different near divergence points + q = np.clip(q, 5e-3, 1 - 5e-3).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @jtu.sample_product( shapes=[ [x_shape, alpha_shape] @@ -326,6 +432,37 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) + @genNamedParametersNArgs(4) + def testGammaLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.gamma.logcdf + lax_fun = lsp_stats.gamma.logcdf + + def args_maker(): + x, a, loc, scale = map(rng, shapes, dtypes) + x = np.clip(x, 0, None).astype(x.dtype) + return [x, a, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testGammaLogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.gamma.sf + lax_fun = lsp_stats.gamma.sf + + def args_maker(): + x, a, loc, scale = map(rng, shapes, dtypes) + return [x, a, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(2) def testGenNormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -411,32 +548,39 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol={np.float32: 1e-5, np.float64: 1e-6}) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticCdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.cdf lax_fun = lsp_stats.logistic.cdf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=3e-5) self._CompileAndCheck(lax_fun, args_maker) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticLogpdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.logpdf lax_fun = lsp_stats.logistic.logpdf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) def testLogisticLogpdfOverflow(self): # Regression test for https://github.com/google/jax/issues/10219 @@ -445,31 +589,56 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), check_dtypes=False) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.ppf lax_fun = lsp_stats.logistic.ppf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) - @genNamedParametersNArgs(1) + @genNamedParametersNArgs(3) def testLogisticSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) scipy_fun = osp_stats.logistic.sf lax_fun = lsp_stats.logistic.sf def args_maker(): - return list(map(rng, shapes, dtypes)) + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=2e-5) - self._CompileAndCheck(lax_fun, args_maker) + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=2e-5) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testLogisticIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.logistic.isf + lax_fun = lsp_stats.logistic.isf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # ensure that scale is not too low + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(3) def testNormLogPdf(self, shapes, dtypes): @@ -524,6 +693,22 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testNormSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.norm.sf + lax_fun = lsp_stats.norm.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-6) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): @@ -543,6 +728,24 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @genNamedParametersNArgs(3) + def testNormIsf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.norm.isf + lax_fun = lsp_stats.norm.isf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + # ensure probability is between 0 and 1: + q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + @genNamedParametersNArgs(5) def testTruncnormLogPdf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -716,6 +919,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) + def testChi2LogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.logcdf + lax_fun = lsp_stats.chi2.logcdf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testChi2Sf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.sf + lax_fun = lsp_stats.chi2.sf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng())