feat: add wrapcauchy logpdf and pdf

This commit is contained in:
Nicola De Angeli 2023-10-16 03:45:24 +02:00
parent 8bf605a7db
commit 890b762a3e
5 changed files with 106 additions and 0 deletions

View File

@ -446,3 +446,12 @@ jax.scipy.stats.vonmises
logpdf
pdf
jax.scipy.stats.wrapcauchy
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.wrapcauchy
.. autosummary::
:toctree: _autosummary
logpdf
pdf

View File

@ -0,0 +1,39 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False)
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
return jnp.where(
lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)),
jnp.where(
lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, jnp.pi * 2)),
jnp.log(1 - c * c) - jnp.log(2 * jnp.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)),
-jnp.inf,
),
jnp.nan,
)
@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False)
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
return lax.exp(logpdf(x, c))

View File

@ -40,3 +40,4 @@ from jax.scipy.stats import truncnorm as truncnorm
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata
from jax.scipy.stats import vonmises as vonmises
from jax.scipy.stats import wrapcauchy as wrapcauchy

View File

@ -0,0 +1,21 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.wrapcauchy import (
logpdf as logpdf,
pdf as pdf,
)

View File

@ -80,6 +80,42 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testWrappedCauchyPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
scipy_fun = osp_stats.wrapcauchy.pdf
lax_fun = lsp_stats.wrapcauchy.pdf
def args_maker():
x = rng(shapes[0], dtypes[0])
c = rng_uniform(shapes[1], dtypes[1])
return [x, c]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol={np.float32: 1e-4})
self._CompileAndCheck(lax_fun, args_maker, tol={np.float32: 1e-5,
np.float64: 1e-11})
@genNamedParametersNArgs(2)
def testWrappedCauchyLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
scipy_fun = osp_stats.wrapcauchy.logpdf
lax_fun = lsp_stats.wrapcauchy.logpdf
def args_maker():
x = rng(shapes[0], dtypes[0])
c = rng_uniform(shapes[1], dtypes[1])
return [x, c]
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol={np.float32: 1e-4})
self._CompileAndCheck(lax_fun, args_maker, tol={np.float32: 1e-5,
np.float64: 1e-13})
@genNamedParametersNArgs(3)
def testPoissonLogPmf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())