Merge pull request #18126 from niqodea:wrapcauchy

PiperOrigin-RevId: 574572631
This commit is contained in:
jax authors 2023-10-18 13:18:20 -07:00
commit 3778265e2e
5 changed files with 112 additions and 0 deletions

View File

@ -446,3 +446,12 @@ jax.scipy.stats.vonmises
logpdf
pdf
jax.scipy.stats.wrapcauchy
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: jax.scipy.stats.wrapcauchy
.. autosummary::
:toctree: _autosummary
logpdf
pdf

View File

@ -0,0 +1,39 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False)
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
return jnp.where(
lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)),
jnp.where(
lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, jnp.pi * 2)),
jnp.log(1 - c * c) - jnp.log(2 * jnp.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)),
-jnp.inf,
),
jnp.nan,
)
@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False)
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
return lax.exp(logpdf(x, c))

View File

@ -40,3 +40,4 @@ from jax.scipy.stats import truncnorm as truncnorm
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata
from jax.scipy.stats import vonmises as vonmises
from jax.scipy.stats import wrapcauchy as wrapcauchy

View File

@ -0,0 +1,21 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.scipy.stats.wrapcauchy import (
logpdf as logpdf,
pdf as pdf,
)

View File

@ -80,6 +80,48 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testWrappedCauchyPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
scipy_fun = osp_stats.wrapcauchy.pdf
lax_fun = lsp_stats.wrapcauchy.pdf
def args_maker():
x = rng(shapes[0], dtypes[0])
c = rng_uniform(shapes[1], dtypes[1])
return [x, c]
tol = {
np.float32: 1e-4 if jtu.test_device_matches(["tpu"]) else 1e-5,
np.float64: 1e-11,
}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
check_dtypes=False, tol=tol)
self._CompileAndCheck(lax_fun, args_maker, tol=tol)
@genNamedParametersNArgs(2)
def testWrappedCauchyLogPdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
scipy_fun = osp_stats.wrapcauchy.logpdf
lax_fun = lsp_stats.wrapcauchy.logpdf
def args_maker():
x = rng(shapes[0], dtypes[0])
c = rng_uniform(shapes[1], dtypes[1])
return [x, c]
tol = {
np.float32: 1e-4 if jtu.test_device_matches(["tpu"]) else 1e-5,
np.float64: 1e-11,
}
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
check_dtypes=False, tol=tol)
self._CompileAndCheck(lax_fun, args_maker, tol=tol)
@genNamedParametersNArgs(3)
def testPoissonLogPmf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())