Added scipy.stats.gennorm.

This commit is contained in:
carlosgmartin 2022-06-11 14:24:19 -04:00
parent 5b808c93d0
commit 57b89ba7cb
6 changed files with 91 additions and 0 deletions

View File

@ -55,6 +55,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
silently truncated to `1`. silently truncated to `1`.
* {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs * {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs
bucket path as input. bucket path as input.
* Added {func}`jax.scipy.stats.gennorm`.
## jaxlib 0.3.11 (Unreleased) ## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -220,6 +220,16 @@ jax.scipy.stats.gamma
logpdf logpdf
pdf pdf
jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.gennorm
.. autosummary::
:toctree: _autosummary
cdf
logpdf
pdf
jax.scipy.stats.geom jax.scipy.stats.geom
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.geom .. automodule:: jax.scipy.stats.geom

View File

@ -0,0 +1,32 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.stats as osp_stats
from jax import lax
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
@_wraps(osp_stats.gennorm.logpdf, update_doc=False)
def logpdf(x, p):
x, p = _promote_args_inexact("gennorm.logpdf", x, p)
return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p
@_wraps(osp_stats.gennorm.cdf, update_doc=False)
def cdf(x, p):
x, p = _promote_args_inexact("gennorm.cdf", x, p)
return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p))
@_wraps(osp_stats.gennorm.pdf, update_doc=False)
def pdf(x, p):
return lax.exp(logpdf(x, p))

View File

@ -30,3 +30,4 @@ from jax.scipy.stats import t as t
from jax.scipy.stats import uniform as uniform from jax.scipy.stats import uniform as uniform
from jax.scipy.stats import chi2 as chi2 from jax.scipy.stats import chi2 as chi2
from jax.scipy.stats import betabinom as betabinom from jax.scipy.stats import betabinom as betabinom
from jax.scipy.stats import gennorm as gennorm

View File

@ -0,0 +1,19 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.scipy.stats.gennorm import (
cdf as cdf,
logpdf as logpdf,
pdf as pdf,
)

View File

@ -245,6 +245,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self.assertAllClose( self.assertAllClose(
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
@genNamedParametersNArgs(2)
def testGenNormLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.gennorm.logpdf
lax_fun = lsp_stats.gennorm.logpdf
def args_maker():
x, p = map(rng, shapes, dtypes)
return [x, p]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testGenNormCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.gennorm.cdf
lax_fun = lsp_stats.gennorm.cdf
def args_maker():
x, p = map(rng, shapes, dtypes)
return [x, p]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4) @genNamedParametersNArgs(4)
def testNBinomLogPmf(self, shapes, dtypes): def testNBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng()) rng = jtu.rand_positive(self.rng())