From 890b762a3e9f7b96486e438f4c5e9cf62c1a8766 Mon Sep 17 00:00:00 2001 From: Nicola De Angeli <112023843+niqodea@users.noreply.github.com> Date: Mon, 16 Oct 2023 03:45:24 +0200 Subject: [PATCH] feat: add wrapcauchy logpdf and pdf --- docs/jax.scipy.rst | 9 +++++++ jax/_src/scipy/stats/wrapcauchy.py | 39 ++++++++++++++++++++++++++++++ jax/scipy/stats/__init__.py | 1 + jax/scipy/stats/wrapcauchy.py | 21 ++++++++++++++++ tests/scipy_stats_test.py | 36 +++++++++++++++++++++++++++ 5 files changed, 106 insertions(+) create mode 100644 jax/_src/scipy/stats/wrapcauchy.py create mode 100644 jax/scipy/stats/wrapcauchy.py diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index e3a6be839..16006ffae 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -446,3 +446,12 @@ jax.scipy.stats.vonmises logpdf pdf + +jax.scipy.stats.wrapcauchy +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.wrapcauchy +.. autosummary:: + :toctree: _autosummary + + logpdf + pdf diff --git a/jax/_src/scipy/stats/wrapcauchy.py b/jax/_src/scipy/stats/wrapcauchy.py new file mode 100644 index 000000000..6f45b5dec --- /dev/null +++ b/jax/_src/scipy/stats/wrapcauchy.py @@ -0,0 +1,39 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import scipy.stats as osp_stats +from jax import lax +import jax.numpy as jnp +from jax._src.lax.lax import _const as _lax_const +from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.typing import Array, ArrayLike + + +@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False) +def logpdf(x: ArrayLike, c: ArrayLike) -> Array: + x, c = promote_args_inexact('wrapcauchy.logpdf', x, c) + return jnp.where( + lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)), + jnp.where( + lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, jnp.pi * 2)), + jnp.log(1 - c * c) - jnp.log(2 * jnp.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)), + -jnp.inf, + ), + jnp.nan, + ) + +@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False) +def pdf(x: ArrayLike, c: ArrayLike) -> Array: + return lax.exp(logpdf(x, c)) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 5d57b1b61..9458f3b7e 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -40,3 +40,4 @@ from jax.scipy.stats import truncnorm as truncnorm from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata from jax.scipy.stats import vonmises as vonmises +from jax.scipy.stats import wrapcauchy as wrapcauchy diff --git a/jax/scipy/stats/wrapcauchy.py b/jax/scipy/stats/wrapcauchy.py new file mode 100644 index 000000000..6e2420c5a --- /dev/null +++ b/jax/scipy/stats/wrapcauchy.py @@ -0,0 +1,21 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.scipy.stats.wrapcauchy import ( + logpdf as logpdf, + pdf as pdf, +) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index afa0bee57..151cf8549 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -80,6 +80,42 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): tol=1e-3) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(2) + def testWrappedCauchyPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3) + scipy_fun = osp_stats.wrapcauchy.pdf + lax_fun = lsp_stats.wrapcauchy.pdf + + def args_maker(): + x = rng(shapes[0], dtypes[0]) + c = rng_uniform(shapes[1], dtypes[1]) + return [x, c] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol={np.float32: 1e-4}) + self._CompileAndCheck(lax_fun, args_maker, tol={np.float32: 1e-5, + np.float64: 1e-11}) + + @genNamedParametersNArgs(2) + def testWrappedCauchyLogPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3) + scipy_fun = osp_stats.wrapcauchy.logpdf + lax_fun = lsp_stats.wrapcauchy.logpdf + + def args_maker(): + x = rng(shapes[0], dtypes[0]) + c = rng_uniform(shapes[1], dtypes[1]) + return [x, c] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol={np.float32: 1e-4}) + self._CompileAndCheck(lax_fun, args_maker, tol={np.float32: 1e-5, + np.float64: 1e-13}) + @genNamedParametersNArgs(3) def testPoissonLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng())