mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
feat: add wrapcauchy logpdf and pdf
This commit is contained in:
parent
8bf605a7db
commit
890b762a3e
@ -446,3 +446,12 @@ jax.scipy.stats.vonmises
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
|
||||
jax.scipy.stats.wrapcauchy
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.wrapcauchy
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
logpdf
|
||||
pdf
|
||||
|
39
jax/_src/scipy/stats/wrapcauchy.py
Normal file
39
jax/_src/scipy/stats/wrapcauchy.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
|
||||
x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
|
||||
return jnp.where(
|
||||
lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1)),
|
||||
jnp.where(
|
||||
lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, jnp.pi * 2)),
|
||||
jnp.log(1 - c * c) - jnp.log(2 * jnp.pi) - jnp.log(1 + c * c - 2 * c * jnp.cos(x)),
|
||||
-jnp.inf,
|
||||
),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x, c))
|
@ -40,3 +40,4 @@ from jax.scipy.stats import truncnorm as truncnorm
|
||||
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
|
||||
from jax._src.scipy.stats._core import mode as mode, rankdata as rankdata
|
||||
from jax.scipy.stats import vonmises as vonmises
|
||||
from jax.scipy.stats import wrapcauchy as wrapcauchy
|
||||
|
21
jax/scipy/stats/wrapcauchy.py
Normal file
21
jax/scipy/stats/wrapcauchy.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.scipy.stats.wrapcauchy import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
@ -80,6 +80,42 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testWrappedCauchyPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
|
||||
scipy_fun = osp_stats.wrapcauchy.pdf
|
||||
lax_fun = lsp_stats.wrapcauchy.pdf
|
||||
|
||||
def args_maker():
|
||||
x = rng(shapes[0], dtypes[0])
|
||||
c = rng_uniform(shapes[1], dtypes[1])
|
||||
return [x, c]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol={np.float32: 1e-4})
|
||||
self._CompileAndCheck(lax_fun, args_maker, tol={np.float32: 1e-5,
|
||||
np.float64: 1e-11})
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testWrappedCauchyLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
rng_uniform = jtu.rand_uniform(self.rng(), low=1e-3, high=1 - 1e-3)
|
||||
scipy_fun = osp_stats.wrapcauchy.logpdf
|
||||
lax_fun = lsp_stats.wrapcauchy.logpdf
|
||||
|
||||
def args_maker():
|
||||
x = rng(shapes[0], dtypes[0])
|
||||
c = rng_uniform(shapes[1], dtypes[1])
|
||||
return [x, c]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol={np.float32: 1e-4})
|
||||
self._CompileAndCheck(lax_fun, args_maker, tol={np.float32: 1e-5,
|
||||
np.float64: 1e-13})
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testPoissonLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user