Added scipy.stats.gennorm.

This commit is contained in:
carlosgmartin 2022-06-11 14:24:19 -04:00
parent 5b808c93d0
commit 57b89ba7cb
6 changed files with 91 additions and 0 deletions

View File

@ -55,6 +55,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
silently truncated to `1`.
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
bucket path as input.
* Added {func}`jax.scipy.stats.gennorm`.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -220,6 +220,16 @@ jax.scipy.stats.gamma
logpdf
pdf
jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.gennorm
.. autosummary::
:toctree: _autosummary
cdf
logpdf
pdf
jax.scipy.stats.geom
~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.geom

View File

@ -0,0 +1,32 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
@_wraps(osp_stats.gennorm.logpdf, update_doc=False)
def logpdf(x, p):
x, p = _promote_args_inexact("gennorm.logpdf", x, p)
return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p
@_wraps(osp_stats.gennorm.cdf, update_doc=False)
def cdf(x, p):
x, p = _promote_args_inexact("gennorm.cdf", x, p)
return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p))
@_wraps(osp_stats.gennorm.pdf, update_doc=False)
def pdf(x, p):
return lax.exp(logpdf(x, p))

View File

@ -30,3 +30,4 @@ from jax.scipy.stats import t as t
from jax.scipy.stats import uniform as uniform
from jax.scipy.stats import chi2 as chi2
from jax.scipy.stats import betabinom as betabinom
from jax.scipy.stats import gennorm as gennorm

View File

@ -0,0 +1,19 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.scipy.stats.gennorm import (
cdf as cdf,
logpdf as logpdf,
pdf as pdf,
)

View File

@ -245,6 +245,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self.assertAllClose(
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
@genNamedParametersNArgs(2)
def testGenNormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.gennorm.logpdf
lax_fun = lsp_stats.gennorm.logpdf
def args_maker():
x, p = map(rng, shapes, dtypes)
return [x, p]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testGenNormCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.gennorm.cdf
lax_fun = lsp_stats.gennorm.cdf
def args_maker():
x, p = map(rng, shapes, dtypes)
return [x, p]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testNBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())