diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 7cf3a71cd..72ef6db35 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -357,3 +357,12 @@ jax.scipy.stats.gaussian_kde gaussian_kde.resample gaussian_kde.pdf gaussian_kde.logpdf + +jax.scipy.stats.vonmises +~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: jax.scipy.stats.vonmises +.. autosummary:: + :toctree: _autosummary + + logpdf + pdf \ No newline at end of file diff --git a/jax/_src/scipy/stats/vonmises.py b/jax/_src/scipy/stats/vonmises.py new file mode 100644 index 000000000..b6a447c65 --- /dev/null +++ b/jax/_src/scipy/stats/vonmises.py @@ -0,0 +1,31 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import scipy.stats as osp_stats +from jax import lax +from jax._src.lax.lax import _const as _lax_const +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.util import _wraps, _promote_args_inexact +from jax._src.typing import Array, ArrayLike + +@_wraps(osp_stats.vonmises.logpdf, update_doc=False) +def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array: + x, kappa = _promote_args_inexact('vonmises.pdf', x, kappa) + zero = _lax_const(kappa, 0) + return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan) + +@_wraps(osp_stats.vonmises.pdf, update_doc=False) +def pdf(x: ArrayLike, kappa: ArrayLike) -> Array: + return lax.exp(logpdf(x, kappa)) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 3d08c7124..56b815412 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -35,3 +35,4 @@ from jax.scipy.stats import gennorm as gennorm from jax.scipy.stats import truncnorm as truncnorm from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde from jax._src.scipy.stats._core import mode as mode +from jax.scipy.stats import vonmises as vonmises diff --git a/jax/scipy/stats/vonmises.py b/jax/scipy/stats/vonmises.py new file mode 100644 index 000000000..1277b5777 --- /dev/null +++ b/jax/scipy/stats/vonmises.py @@ -0,0 +1,18 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.scipy.stats.vonmises import ( + logpdf as logpdf, + pdf as pdf, +) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index a5f4489a9..92b246a21 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -49,6 +49,38 @@ def genNamedParametersNArgs(n): class LaxBackedScipyStatsTests(jtu.JaxTestCase): """Tests for LAX-backed scipy.stats implementations""" + @genNamedParametersNArgs(2) + def testVonMisesPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.vonmises.pdf + lax_fun = lsp_stats.vonmises.pdf + + def args_maker(): + x, kappa = map(rng, shapes, dtypes) + kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype) + return [x, kappa] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(2) + def testVonMisesLogPdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.vonmises.pdf + lax_fun = lsp_stats.vonmises.pdf + + def args_maker(): + x, kappa = map(rng, shapes, dtypes) + kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype) + return [x, kappa] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) def testPoissonLogPmf(self, shapes, dtypes): rng = jtu.rand_default(self.rng())