mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
feat(gh-13291): Add exponential distribution functions: cdf, logcdf, sf, logsf, and ppf
This commit is contained in:
parent
ed952c8e65
commit
42b64fc06c
@ -12,8 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax import lax
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
@ -41,7 +42,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
array of logpdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
|
||||
log_scale = lax.log(scale)
|
||||
@ -73,6 +80,188 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``cdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the exponential distribution probability density function,
|
||||
:func:`jax.scipy.stats.expon.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
return jnp.where(
|
||||
lax.lt(x, loc), jnp.zeros_like(scaled_x), lax.neg(lax.expm1(lax.neg(scaled_x)))
|
||||
)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential log cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the exponential distribution probability density function,
|
||||
:func:`jax.scipy.stats.expon.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.log(cdf(x, loc, scale))
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``logsf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
|
||||
:func:`jax.scipy.stats.expon.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
|
||||
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||
return jnp.where(lax.lt(x, loc), jnp.zeros_like(scaled_x), lax.neg(scaled_x))
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
|
||||
:func:`jax.scipy.stats.expon.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.exp(logsf(x, loc, scale))
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``ppf``.
|
||||
|
||||
The percent point function is defined as the inverse of the
|
||||
cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
|
||||
scaled_q = lax.div(lax.sub(q, loc), scale)
|
||||
return jnp.where(
|
||||
jnp.isnan(q) | (q < 0) | (q > 1),
|
||||
jnp.nan,
|
||||
lax.neg(lax.log1p(lax.neg(scaled_q))),
|
||||
)
|
||||
|
@ -16,6 +16,11 @@
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.stats.expon import (
|
||||
cdf as cdf,
|
||||
logcdf as logcdf,
|
||||
logpdf as logpdf,
|
||||
logsf as logsf,
|
||||
pdf as pdf,
|
||||
ppf as ppf,
|
||||
sf as sf,
|
||||
)
|
||||
|
@ -523,6 +523,86 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testExponLogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.expon.logcdf
|
||||
lax_fun = lsp_stats.expon.logcdf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||
)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testExponCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.expon.cdf
|
||||
lax_fun = lsp_stats.expon.cdf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||
)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testExponSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.expon.sf
|
||||
lax_fun = lsp_stats.expon.sf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||
)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testExponLogSf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.expon.logsf
|
||||
lax_fun = lsp_stats.expon.logsf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||
)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testExponPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.expon.ppf
|
||||
lax_fun = lsp_stats.expon.ppf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
return [q, loc, scale]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||
)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user