mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Added scipy.stats.gennorm.
This commit is contained in:
parent
5b808c93d0
commit
57b89ba7cb
@ -55,6 +55,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
silently truncated to `1`.
|
||||
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
|
||||
bucket path as input.
|
||||
* Added {func}`jax.scipy.stats.gennorm`.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
@ -220,6 +220,16 @@ jax.scipy.stats.gamma
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
jax.scipy.stats.gennorm
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.gennorm
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
cdf
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
jax.scipy.stats.geom
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.geom
|
||||
|
32
jax/_src/scipy/stats/gennorm.py
Normal file
32
jax/_src/scipy/stats/gennorm.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.lax_numpy import _promote_args_inexact
|
||||
|
||||
@_wraps(osp_stats.gennorm.logpdf, update_doc=False)
|
||||
def logpdf(x, p):
|
||||
x, p = _promote_args_inexact("gennorm.logpdf", x, p)
|
||||
return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p
|
||||
|
||||
@_wraps(osp_stats.gennorm.cdf, update_doc=False)
|
||||
def cdf(x, p):
|
||||
x, p = _promote_args_inexact("gennorm.cdf", x, p)
|
||||
return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p))
|
||||
|
||||
@_wraps(osp_stats.gennorm.pdf, update_doc=False)
|
||||
def pdf(x, p):
|
||||
return lax.exp(logpdf(x, p))
|
@ -30,3 +30,4 @@ from jax.scipy.stats import t as t
|
||||
from jax.scipy.stats import uniform as uniform
|
||||
from jax.scipy.stats import chi2 as chi2
|
||||
from jax.scipy.stats import betabinom as betabinom
|
||||
from jax.scipy.stats import gennorm as gennorm
|
||||
|
19
jax/scipy/stats/gennorm.py
Normal file
19
jax/scipy/stats/gennorm.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.scipy.stats.gennorm import (
|
||||
cdf as cdf,
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
@ -245,6 +245,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(
|
||||
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testGenNormLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.gennorm.logpdf
|
||||
lax_fun = lsp_stats.gennorm.logpdf
|
||||
|
||||
def args_maker():
|
||||
x, p = map(rng, shapes, dtypes)
|
||||
return [x, p]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testGenNormCdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.gennorm.cdf
|
||||
lax_fun = lsp_stats.gennorm.cdf
|
||||
|
||||
def args_maker():
|
||||
x, p = map(rng, shapes, dtypes)
|
||||
return [x, p]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testNBinomLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user