mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00
Add .hypothesis/ directory to .gitignore
and ppf and cdf to scipy.stats.uniform
This commit is contained in:
parent
c0d51e7dde
commit
390e90361a
1
.gitignore
vendored
1
.gitignore
vendored
@ -21,6 +21,7 @@
|
||||
.envrc
|
||||
jax.iml
|
||||
.bazelrc.user
|
||||
.hypothesis/
|
||||
|
||||
# virtualenv/venv directories
|
||||
/venv/
|
||||
|
@ -425,6 +425,8 @@ jax.scipy.stats.uniform
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
cdf
|
||||
ppf
|
||||
|
||||
jax.scipy.stats.gaussian_kde
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -16,6 +16,7 @@
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax.numpy import where, inf, logical_or
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
@ -32,3 +33,21 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
@_wraps(osp_stats.uniform.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
@_wraps(osp_stats.uniform.cdf, update_doc=False)
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
|
||||
zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
|
||||
conds = [lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)), lax.ge(x, loc) & lax.le(x, lax.add(loc, scale))]
|
||||
vals = [zero, one, lax.div(lax.sub(x, loc), scale)]
|
||||
|
||||
return jnp.select(conds, vals)
|
||||
|
||||
@_wraps(osp_stats.uniform.ppf, update_doc=False)
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
|
||||
return where(
|
||||
jnp.isnan(q) | (q < 0) | (q > 1),
|
||||
jnp.nan,
|
||||
lax.add(loc, lax.mul(scale, q))
|
||||
)
|
||||
|
@ -18,4 +18,6 @@
|
||||
from jax._src.scipy.stats.uniform import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
cdf as cdf,
|
||||
ppf as ppf,
|
||||
)
|
||||
|
@ -1043,6 +1043,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testUniformCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.uniform.cdf
|
||||
lax_fun = lsp_stats.uniform.cdf
|
||||
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, np.abs(scale)]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testUniformPpf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.uniform.ppf
|
||||
lax_fun = lsp_stats.uniform.ppf
|
||||
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
return [q, loc, np.abs(scale)]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-5)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2LogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
@ -1058,6 +1088,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2LogCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user