mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
stats.norm: add logsf & make sf more accurate near zero
This commit is contained in:
parent
2f28848c7c
commit
cf11f8da8a
@ -57,10 +57,16 @@ def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return jnp.asarray(special.ndtri(q) * scale + loc, float)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.logsf, update_doc=False)
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
|
||||
return logcdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.sf, update_doc=False)
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
cdf_result = cdf(x, loc, scale)
|
||||
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
|
||||
x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
|
||||
return cdf(-x, -loc, scale)
|
||||
|
||||
|
||||
@_wraps(osp_stats.norm.isf, update_doc=False)
|
||||
|
@ -19,6 +19,7 @@ from jax._src.scipy.stats.norm import (
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
logsf as logsf,
|
||||
pdf as pdf,
|
||||
ppf as ppf,
|
||||
sf as sf,
|
||||
|
@ -693,6 +693,23 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.norm.logsf
|
||||
lax_fun = lsp_stats.norm.logsf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -710,6 +727,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testNormSfNearZero(self):
|
||||
# Regression test for https://github.com/google/jax/issues/17199
|
||||
value = np.array(10, np.float32)
|
||||
self.assertAllClose(osp_stats.norm.sf(value).astype('float32'),
|
||||
lsp_stats.norm.sf(value),
|
||||
atol=0, rtol=1E-5)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testNormPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user