mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
feat(gh-13291): Add exponential distribution functions: cdf, logcdf, sf, logsf, and ppf
This commit is contained in:
parent
ed952c8e65
commit
42b64fc06c
@ -12,8 +12,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from jax import lax
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from jax import lax
|
||||||
from jax._src.numpy.util import promote_args_inexact
|
from jax._src.numpy.util import promote_args_inexact
|
||||||
from jax._src.typing import Array, ArrayLike
|
from jax._src.typing import Array, ArrayLike
|
||||||
|
|
||||||
@ -41,7 +42,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
|||||||
array of logpdf values.
|
array of logpdf values.
|
||||||
|
|
||||||
See Also:
|
See Also:
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`
|
||||||
:func:`jax.scipy.stats.expon.pdf`
|
:func:`jax.scipy.stats.expon.pdf`
|
||||||
|
:func:`jax.scipy.stats.expon.ppf`
|
||||||
|
:func:`jax.scipy.stats.expon.sf`
|
||||||
|
:func:`jax.scipy.stats.expon.logcdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logpdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logsf`
|
||||||
"""
|
"""
|
||||||
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
|
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
|
||||||
log_scale = lax.log(scale)
|
log_scale = lax.log(scale)
|
||||||
@ -73,6 +80,188 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
|||||||
array of pdf values.
|
array of pdf values.
|
||||||
|
|
||||||
See Also:
|
See Also:
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`
|
||||||
|
:func:`jax.scipy.stats.expon.ppf`
|
||||||
|
:func:`jax.scipy.stats.expon.sf`
|
||||||
|
:func:`jax.scipy.stats.expon.logcdf`
|
||||||
:func:`jax.scipy.stats.expon.logpdf`
|
:func:`jax.scipy.stats.expon.logpdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logsf`
|
||||||
"""
|
"""
|
||||||
return lax.exp(logpdf(x, loc, scale))
|
return lax.exp(logpdf(x, loc, scale))
|
||||||
|
|
||||||
|
|
||||||
|
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
r"""Exponential cumulative density function.
|
||||||
|
|
||||||
|
JAX implementation of :obj:`scipy.stats.expon` ``cdf``.
|
||||||
|
|
||||||
|
The cdf is defined as
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
|
||||||
|
|
||||||
|
where :math:`f_{pdf}` is the exponential distribution probability density function,
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: arraylike, value at which to evaluate the PDF
|
||||||
|
loc: arraylike, distribution offset parameter
|
||||||
|
scale: arraylike, distribution scale parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array of pdf values.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`
|
||||||
|
:func:`jax.scipy.stats.expon.ppf`
|
||||||
|
:func:`jax.scipy.stats.expon.sf`
|
||||||
|
:func:`jax.scipy.stats.expon.logcdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logpdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logsf`
|
||||||
|
"""
|
||||||
|
x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
|
||||||
|
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||||
|
return jnp.where(
|
||||||
|
lax.lt(x, loc), jnp.zeros_like(scaled_x), lax.neg(lax.expm1(lax.neg(scaled_x)))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
r"""Exponential log cumulative density function.
|
||||||
|
|
||||||
|
JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.
|
||||||
|
|
||||||
|
The cdf is defined as
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
|
||||||
|
|
||||||
|
where :math:`f_{pdf}` is the exponential distribution probability density function,
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: arraylike, value at which to evaluate the PDF
|
||||||
|
loc: arraylike, distribution offset parameter
|
||||||
|
scale: arraylike, distribution scale parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array of pdf values.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`
|
||||||
|
:func:`jax.scipy.stats.expon.ppf`
|
||||||
|
:func:`jax.scipy.stats.expon.sf`
|
||||||
|
:func:`jax.scipy.stats.expon.logcdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logpdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logsf`
|
||||||
|
"""
|
||||||
|
return lax.log(cdf(x, loc, scale))
|
||||||
|
|
||||||
|
|
||||||
|
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
r"""Exponential log survival function.
|
||||||
|
|
||||||
|
JAX implementation of :obj:`scipy.stats.expon` ``logsf``.
|
||||||
|
|
||||||
|
The survival function is defined as
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||||
|
|
||||||
|
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: arraylike, value at which to evaluate the PDF
|
||||||
|
loc: arraylike, distribution offset parameter
|
||||||
|
scale: arraylike, distribution scale parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array of pdf values.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`
|
||||||
|
:func:`jax.scipy.stats.expon.ppf`
|
||||||
|
:func:`jax.scipy.stats.expon.sf`
|
||||||
|
:func:`jax.scipy.stats.expon.logcdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logpdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logsf`
|
||||||
|
"""
|
||||||
|
x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
|
||||||
|
scaled_x = lax.div(lax.sub(x, loc), scale)
|
||||||
|
return jnp.where(lax.lt(x, loc), jnp.zeros_like(scaled_x), lax.neg(scaled_x))
|
||||||
|
|
||||||
|
|
||||||
|
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
r"""Exponential survival function.
|
||||||
|
|
||||||
|
JAX implementation of :obj:`scipy.stats.expon` ``sf``.
|
||||||
|
|
||||||
|
The survival function is defined as
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||||
|
|
||||||
|
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: arraylike, value at which to evaluate the PDF
|
||||||
|
loc: arraylike, distribution offset parameter
|
||||||
|
scale: arraylike, distribution scale parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array of pdf values.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`
|
||||||
|
:func:`jax.scipy.stats.expon.ppf`
|
||||||
|
:func:`jax.scipy.stats.expon.sf`
|
||||||
|
:func:`jax.scipy.stats.expon.logcdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logpdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logsf`
|
||||||
|
"""
|
||||||
|
return lax.exp(logsf(x, loc, scale))
|
||||||
|
|
||||||
|
|
||||||
|
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
r"""Exponential survival function.
|
||||||
|
|
||||||
|
JAX implementation of :obj:`scipy.stats.expon` ``ppf``.
|
||||||
|
|
||||||
|
The percent point function is defined as the inverse of the
|
||||||
|
cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: arraylike, value at which to evaluate the PDF
|
||||||
|
loc: arraylike, distribution offset parameter
|
||||||
|
scale: arraylike, distribution scale parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array of pdf values.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
:func:`jax.scipy.stats.expon.cdf`
|
||||||
|
:func:`jax.scipy.stats.expon.pdf`
|
||||||
|
:func:`jax.scipy.stats.expon.ppf`
|
||||||
|
:func:`jax.scipy.stats.expon.sf`
|
||||||
|
:func:`jax.scipy.stats.expon.logcdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logpdf`
|
||||||
|
:func:`jax.scipy.stats.expon.logsf`
|
||||||
|
"""
|
||||||
|
q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
|
||||||
|
scaled_q = lax.div(lax.sub(q, loc), scale)
|
||||||
|
return jnp.where(
|
||||||
|
jnp.isnan(q) | (q < 0) | (q > 1),
|
||||||
|
jnp.nan,
|
||||||
|
lax.neg(lax.log1p(lax.neg(scaled_q))),
|
||||||
|
)
|
||||||
|
@ -16,6 +16,11 @@
|
|||||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||||
|
|
||||||
from jax._src.scipy.stats.expon import (
|
from jax._src.scipy.stats.expon import (
|
||||||
|
cdf as cdf,
|
||||||
|
logcdf as logcdf,
|
||||||
logpdf as logpdf,
|
logpdf as logpdf,
|
||||||
|
logsf as logsf,
|
||||||
pdf as pdf,
|
pdf as pdf,
|
||||||
|
ppf as ppf,
|
||||||
|
sf as sf,
|
||||||
)
|
)
|
||||||
|
@ -523,6 +523,86 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
tol=1e-4)
|
tol=1e-4)
|
||||||
self._CompileAndCheck(lax_fun, args_maker)
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testExponLogCdf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.expon.logcdf
|
||||||
|
lax_fun = lsp_stats.expon.logcdf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(
|
||||||
|
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||||
|
)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testExponCdf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.expon.cdf
|
||||||
|
lax_fun = lsp_stats.expon.cdf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(
|
||||||
|
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||||
|
)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testExponSf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.expon.sf
|
||||||
|
lax_fun = lsp_stats.expon.sf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(
|
||||||
|
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||||
|
)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testExponLogSf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.expon.logsf
|
||||||
|
lax_fun = lsp_stats.expon.logsf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(
|
||||||
|
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||||
|
)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testExponPpf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
scipy_fun = osp_stats.expon.ppf
|
||||||
|
lax_fun = lsp_stats.expon.ppf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
q, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [q, loc, scale]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(
|
||||||
|
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
|
||||||
|
)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
@genNamedParametersNArgs(4)
|
@genNamedParametersNArgs(4)
|
||||||
def testGammaLogPdf(self, shapes, dtypes):
|
def testGammaLogPdf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_positive(self.rng())
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user