feat(gh-13291): Add exponential distribution functions: cdf, logcdf, sf, logsf, and ppf

This commit is contained in:
Qazalbash 2025-02-01 12:51:11 +05:00
parent ed952c8e65
commit 42b64fc06c
No known key found for this signature in database
GPG Key ID: 624E2F28F6A2AAA7
3 changed files with 275 additions and 1 deletions

View File

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import lax
import jax.numpy as jnp
from jax import lax
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
@ -41,7 +42,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
log_scale = lax.log(scale)
@ -73,6 +80,188 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logpdf(x, loc, scale))
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential cumulative density function.
JAX implementation of :obj:`scipy.stats.expon` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
scaled_x = lax.div(lax.sub(x, loc), scale)
return jnp.where(
lax.lt(x, loc), jnp.zeros_like(scaled_x), lax.neg(lax.expm1(lax.neg(scaled_x)))
)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log cumulative density function.
JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.log(cdf(x, loc, scale))
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log survival function.
JAX implementation of :obj:`scipy.stats.expon` ``logsf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
scaled_x = lax.div(lax.sub(x, loc), scale)
return jnp.where(lax.lt(x, loc), jnp.zeros_like(scaled_x), lax.neg(scaled_x))
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.
JAX implementation of :obj:`scipy.stats.expon` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logsf(x, loc, scale))
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.
JAX implementation of :obj:`scipy.stats.expon` ``ppf``.
The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
scaled_q = lax.div(lax.sub(q, loc), scale)
return jnp.where(
jnp.isnan(q) | (q < 0) | (q > 1),
jnp.nan,
lax.neg(lax.log1p(lax.neg(scaled_q))),
)

View File

@ -16,6 +16,11 @@
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.expon import (
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
ppf as ppf,
sf as sf,
)

View File

@ -523,6 +523,86 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testExponLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.logcdf
lax_fun = lsp_stats.expon.logcdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testExponCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.cdf
lax_fun = lsp_stats.expon.cdf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testExponSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.sf
lax_fun = lsp_stats.expon.sf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testExponLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.logsf
lax_fun = lsp_stats.expon.logsf
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testExponPpf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.ppf
lax_fun = lsp_stats.expon.ppf
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
return [q, loc, scale]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())