stats.norm: add logsf & make sf more accurate near zero

This commit is contained in:
Jake VanderPlas 2023-08-21 16:48:39 -07:00
parent 2f28848c7c
commit cf11f8da8a
3 changed files with 33 additions and 2 deletions

View File

@ -57,10 +57,16 @@ def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return jnp.asarray(special.ndtri(q) * scale + loc, float)
@_wraps(osp_stats.norm.logsf, update_doc=False)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
return logcdf(-x, -loc, scale)
@_wraps(osp_stats.norm.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
return cdf(-x, -loc, scale)
@_wraps(osp_stats.norm.isf, update_doc=False)

View File

@ -19,6 +19,7 @@ from jax._src.scipy.stats.norm import (
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
ppf as ppf,
sf as sf,

View File

@ -693,6 +693,23 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormLogSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.norm.logsf
lax_fun = lsp_stats.norm.logsf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testNormSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
@ -710,6 +727,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
def testNormSfNearZero(self):
# Regression test for https://github.com/google/jax/issues/17199
value = np.array(10, np.float32)
self.assertAllClose(osp_stats.norm.sf(value).astype('float32'),
lsp_stats.norm.sf(value),
atol=0, rtol=1E-5)
@genNamedParametersNArgs(3)
def testNormPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())