mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Added scipy.stats.gennorm.
This commit is contained in:
parent
5b808c93d0
commit
57b89ba7cb
@ -55,6 +55,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
silently truncated to `1`.
|
silently truncated to `1`.
|
||||||
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
|
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
|
||||||
bucket path as input.
|
bucket path as input.
|
||||||
|
* Added {func}`jax.scipy.stats.gennorm`.
|
||||||
|
|
||||||
## jaxlib 0.3.11 (Unreleased)
|
## jaxlib 0.3.11 (Unreleased)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||||
|
@ -220,6 +220,16 @@ jax.scipy.stats.gamma
|
|||||||
logpdf
|
logpdf
|
||||||
pdf
|
pdf
|
||||||
|
|
||||||
|
jax.scipy.stats.gennorm
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
.. automodule:: jax.scipy.stats.gennorm
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
cdf
|
||||||
|
logpdf
|
||||||
|
pdf
|
||||||
|
|
||||||
jax.scipy.stats.geom
|
jax.scipy.stats.geom
|
||||||
~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
.. automodule:: jax.scipy.stats.geom
|
.. automodule:: jax.scipy.stats.geom
|
||||||
|
32
jax/_src/scipy/stats/gennorm.py
Normal file
32
jax/_src/scipy/stats/gennorm.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2022 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import scipy.stats as osp_stats
|
||||||
|
from jax import lax
|
||||||
|
from jax._src.numpy.util import _wraps
|
||||||
|
from jax._src.numpy.lax_numpy import _promote_args_inexact
|
||||||
|
|
||||||
|
@_wraps(osp_stats.gennorm.logpdf, update_doc=False)
|
||||||
|
def logpdf(x, p):
|
||||||
|
x, p = _promote_args_inexact("gennorm.logpdf", x, p)
|
||||||
|
return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p
|
||||||
|
|
||||||
|
@_wraps(osp_stats.gennorm.cdf, update_doc=False)
|
||||||
|
def cdf(x, p):
|
||||||
|
x, p = _promote_args_inexact("gennorm.cdf", x, p)
|
||||||
|
return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p))
|
||||||
|
|
||||||
|
@_wraps(osp_stats.gennorm.pdf, update_doc=False)
|
||||||
|
def pdf(x, p):
|
||||||
|
return lax.exp(logpdf(x, p))
|
@ -30,3 +30,4 @@ from jax.scipy.stats import t as t
|
|||||||
from jax.scipy.stats import uniform as uniform
|
from jax.scipy.stats import uniform as uniform
|
||||||
from jax.scipy.stats import chi2 as chi2
|
from jax.scipy.stats import chi2 as chi2
|
||||||
from jax.scipy.stats import betabinom as betabinom
|
from jax.scipy.stats import betabinom as betabinom
|
||||||
|
from jax.scipy.stats import gennorm as gennorm
|
||||||
|
19
jax/scipy/stats/gennorm.py
Normal file
19
jax/scipy/stats/gennorm.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2022 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from jax._src.scipy.stats.gennorm import (
|
||||||
|
cdf as cdf,
|
||||||
|
logpdf as logpdf,
|
||||||
|
pdf as pdf,
|
||||||
|
)
|
@ -245,6 +245,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
|||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
|
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(2)
|
||||||
|
def testGenNormLogPdf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
scipy_fun = osp_stats.gennorm.logpdf
|
||||||
|
lax_fun = lsp_stats.gennorm.logpdf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, p = map(rng, shapes, dtypes)
|
||||||
|
return [x, p]
|
||||||
|
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=1e-4, rtol=1e-3)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
|
@genNamedParametersNArgs(2)
|
||||||
|
def testGenNormCdf(self, shapes, dtypes):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
scipy_fun = osp_stats.gennorm.cdf
|
||||||
|
lax_fun = lsp_stats.gennorm.cdf
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
x, p = map(rng, shapes, dtypes)
|
||||||
|
return [x, p]
|
||||||
|
|
||||||
|
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||||
|
tol=1e-4, rtol=1e-3)
|
||||||
|
self._CompileAndCheck(lax_fun, args_maker)
|
||||||
|
|
||||||
@genNamedParametersNArgs(4)
|
@genNamedParametersNArgs(4)
|
||||||
def testNBinomLogPmf(self, shapes, dtypes):
|
def testNBinomLogPmf(self, shapes, dtypes):
|
||||||
rng = jtu.rand_positive(self.rng())
|
rng = jtu.rand_positive(self.rng())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user