diff --git a/CHANGELOG.md b/CHANGELOG.md index 0436d9639..d0215c312 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. silently truncated to `1`. * {func}`jax.experimental.compilation_cache.initialize_cache` now supports gcs bucket path as input. + * Added {func}`jax.scipy.stats.gennorm`. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index efdfa85af..d099c02fc 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -220,6 +220,16 @@ jax.scipy.stats.gamma logpdf pdf +jax.scipy.stats.gennorm +~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.gennorm +.. autosummary:: + :toctree: _autosummary + + cdf + logpdf + pdf + jax.scipy.stats.geom ~~~~~~~~~~~~~~~~~~~~ .. automodule:: jax.scipy.stats.geom diff --git a/jax/_src/scipy/stats/gennorm.py b/jax/_src/scipy/stats/gennorm.py new file mode 100644 index 000000000..d4b55b050 --- /dev/null +++ b/jax/_src/scipy/stats/gennorm.py @@ -0,0 +1,32 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import scipy.stats as osp_stats +from jax import lax +from jax._src.numpy.util import _wraps +from jax._src.numpy.lax_numpy import _promote_args_inexact + +@_wraps(osp_stats.gennorm.logpdf, update_doc=False) +def logpdf(x, p): + x, p = _promote_args_inexact("gennorm.logpdf", x, p) + return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p + +@_wraps(osp_stats.gennorm.cdf, update_doc=False) +def cdf(x, p): + x, p = _promote_args_inexact("gennorm.cdf", x, p) + return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p)) + +@_wraps(osp_stats.gennorm.pdf, update_doc=False) +def pdf(x, p): + return lax.exp(logpdf(x, p)) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index bf36cf77f..9876d2edc 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -30,3 +30,4 @@ from jax.scipy.stats import t as t from jax.scipy.stats import uniform as uniform from jax.scipy.stats import chi2 as chi2 from jax.scipy.stats import betabinom as betabinom +from jax.scipy.stats import gennorm as gennorm diff --git a/jax/scipy/stats/gennorm.py b/jax/scipy/stats/gennorm.py new file mode 100644 index 000000000..2d069773d --- /dev/null +++ b/jax/scipy/stats/gennorm.py @@ -0,0 +1,19 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.scipy.stats.gennorm import ( + cdf as cdf, + logpdf as logpdf, + pdf as pdf, +) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index f691c1db8..fb609ce13 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -245,6 +245,34 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) + @genNamedParametersNArgs(2) + def testGenNormLogPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gennorm.logpdf + lax_fun = lsp_stats.gennorm.logpdf + + def args_maker(): + x, p = map(rng, shapes, dtypes) + return [x, p] + + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4, rtol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(2) + def testGenNormCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.gennorm.cdf + lax_fun = lsp_stats.gennorm.cdf + + def args_maker(): + x, p = map(rng, shapes, dtypes) + return [x, p] + + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4, rtol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) def testNBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng())