diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 4c72bcac5..4f913cc14 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -57,10 +57,16 @@ def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.asarray(special.ndtri(q) * scale + loc, float) +@_wraps(osp_stats.norm.logsf, update_doc=False) +def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale) + return logcdf(-x, -loc, scale) + + @_wraps(osp_stats.norm.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - cdf_result = cdf(x, loc, scale) - return lax.sub(_lax_const(cdf_result, 1), cdf_result) + x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale) + return cdf(-x, -loc, scale) @_wraps(osp_stats.norm.isf, update_doc=False) diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index c6b85f25d..f47765adf 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -19,6 +19,7 @@ from jax._src.scipy.stats.norm import ( cdf as cdf, logcdf as logcdf, logpdf as logpdf, + logsf as logsf, pdf as pdf, ppf as ppf, sf as sf, diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index b73728f17..ad9549ca1 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -693,6 +693,23 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testNormLogSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.norm.logsf + lax_fun = lsp_stats.norm.logsf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) def testNormSf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -710,6 +727,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-6) self._CompileAndCheck(lax_fun, args_maker) + def testNormSfNearZero(self): + # Regression test for https://github.com/google/jax/issues/17199 + value = np.array(10, np.float32) + self.assertAllClose(osp_stats.norm.sf(value).astype('float32'), + lsp_stats.norm.sf(value), + atol=0, rtol=1E-5) + @genNamedParametersNArgs(3) def testNormPpf(self, shapes, dtypes): rng = jtu.rand_default(self.rng())