mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add .hypothesis/ directory to .gitignore
and ppf and cdf to scipy.stats.uniform
This commit is contained in:
parent
c0d51e7dde
commit
390e90361a
1
.gitignore
vendored
1
.gitignore
vendored
@ -21,6 +21,7 @@
|
|||||||
.envrc
|
.envrc
|
||||||
jax.iml
|
jax.iml
|
||||||
.bazelrc.user
|
.bazelrc.user
|
||||||
|
.hypothesis/
|
||||||
|
|
||||||
# virtualenv/venv directories
|
# virtualenv/venv directories
|
||||||
/venv/
|
/venv/
|
||||||
|
@ -425,6 +425,8 @@ jax.scipy.stats.uniform
|
|||||||
|
|
||||||
logpdf
|
logpdf
|
||||||
pdf
|
pdf
|
||||||
|
cdf
|
||||||
|
ppf
|
||||||
|
|
||||||
jax.scipy.stats.gaussian_kde
|
jax.scipy.stats.gaussian_kde
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
import scipy.stats as osp_stats
|
import scipy.stats as osp_stats
|
||||||
|
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
from jax import numpy as jnp
|
||||||
from jax.numpy import where, inf, logical_or
|
from jax.numpy import where, inf, logical_or
|
||||||
from jax._src.typing import Array, ArrayLike
|
from jax._src.typing import Array, ArrayLike
|
||||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||||
@ -32,3 +33,21 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
|||||||
@_wraps(osp_stats.uniform.pdf, update_doc=False)
|
@_wraps(osp_stats.uniform.pdf, update_doc=False)
|
||||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
return lax.exp(logpdf(x, loc, scale))
|
return lax.exp(logpdf(x, loc, scale))
|
||||||
|
|
||||||
|
@_wraps(osp_stats.uniform.cdf, update_doc=False)
|
||||||
|
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
|
||||||
|
zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
|
||||||
|
conds = [lax.lt(x, loc), lax.gt(x, lax.add(loc, scale)), lax.ge(x, loc) & lax.le(x, lax.add(loc, scale))]
|
||||||
|
vals = [zero, one, lax.div(lax.sub(x, loc), scale)]
|
||||||
|
|
||||||
|
return jnp.select(conds, vals)
|
||||||
|
|
||||||
|
@_wraps(osp_stats.uniform.ppf, update_doc=False)
|
||||||
|
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||||
|
q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
|
||||||
|
return where(
|
||||||
|
jnp.isnan(q) | (q < 0) | (q > 1),
|
||||||
|
jnp.nan,
|
||||||
|
lax.add(loc, lax.mul(scale, q))
|
||||||
|
)
|
||||||
|
@ -18,4 +18,6 @@
|
|||||||
from jax._src.scipy.stats.uniform import (
|
from jax._src.scipy.stats.uniform import (
|
||||||
logpdf as logpdf,
|
logpdf as logpdf,
|
||||||
pdf as pdf,
|
pdf as pdf,
|
||||||
|
cdf as cdf,
|
||||||
|
ppf as ppf,
|
||||||
)
|
)
|
||||||
|
@ -1043,6 +1043,36 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
tol=1e-4)
|
tol=1e-4)
|
||||||
self._CompileAndCheck(lax_fun, args_maker)
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testUniformCdf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
scipy_fun = osp_stats.uniform.cdf
|
||||||
|
lax_fun = lsp_stats.uniform.cdf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [x, loc, np.abs(scale)]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=1e-5)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(3)
|
||||||
|
def testUniformPpf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
scipy_fun = osp_stats.uniform.ppf
|
||||||
|
lax_fun = lsp_stats.uniform.ppf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
q, loc, scale = map(rng, shapes, dtypes)
|
||||||
|
return [q, loc, np.abs(scale)]
|
||||||
|
|
||||||
|
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=1e-5)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
@genNamedParametersNArgs(4)
|
@genNamedParametersNArgs(4)
|
||||||
def testChi2LogPdf(self, shapes, dtypes):
|
def testChi2LogPdf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_positive(self.rng())
|
rng = jtu.rand_positive(self.rng())
|
||||||
@ -1058,6 +1088,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
tol=5e-4)
|
tol=5e-4)
|
||||||
self._CompileAndCheck(lax_fun, args_maker)
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
|
||||||
@genNamedParametersNArgs(4)
|
@genNamedParametersNArgs(4)
|
||||||
def testChi2LogCdf(self, shapes, dtypes):
|
def testChi2LogCdf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_positive(self.rng())
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user