mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Added vonmises pdf, logpdf & respective tests.
Added vonmises pdf, logpdf & respective tests. Altered type-hinting, added pi as a _lax_const Changed lax constant pi to be created in _pdf instead of passed arg. Changed name in __init__.py Fixed bug in tests. Review related alterations. Review related changes. Added vonmises pdf, logpdf & respective tests. Added vonmises pdf, logpdf & respective tests. Altered type-hinting, added pi as a _lax_const Changed lax constant pi to be created in _pdf instead of passed arg. Changed name in __init__.py Fixed bug in tests. Review related alterations. PR PR PR
This commit is contained in:
parent
5832dfd812
commit
351e1874ab
@ -357,3 +357,12 @@ jax.scipy.stats.gaussian_kde
|
||||
gaussian_kde.resample
|
||||
gaussian_kde.pdf
|
||||
gaussian_kde.logpdf
|
||||
|
||||
jax.scipy.stats.vonmises
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: jax.scipy.stats.vonmises
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
logpdf
|
||||
pdf
|
31
jax/_src/scipy/stats/vonmises.py
Normal file
31
jax/_src/scipy/stats/vonmises.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import scipy.stats as osp_stats
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, _promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
@_wraps(osp_stats.vonmises.logpdf, update_doc=False)
|
||||
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
|
||||
x, kappa = _promote_args_inexact('vonmises.pdf', x, kappa)
|
||||
zero = _lax_const(kappa, 0)
|
||||
return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan)
|
||||
|
||||
@_wraps(osp_stats.vonmises.pdf, update_doc=False)
|
||||
def pdf(x: ArrayLike, kappa: ArrayLike) -> Array:
|
||||
return lax.exp(logpdf(x, kappa))
|
@ -35,3 +35,4 @@ from jax.scipy.stats import gennorm as gennorm
|
||||
from jax.scipy.stats import truncnorm as truncnorm
|
||||
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
|
||||
from jax._src.scipy.stats._core import mode as mode
|
||||
from jax.scipy.stats import vonmises as vonmises
|
||||
|
18
jax/scipy/stats/vonmises.py
Normal file
18
jax/scipy/stats/vonmises.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.scipy.stats.vonmises import (
|
||||
logpdf as logpdf,
|
||||
pdf as pdf,
|
||||
)
|
@ -49,6 +49,38 @@ def genNamedParametersNArgs(n):
|
||||
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed scipy.stats implementations"""
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testVonMisesPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.vonmises.pdf
|
||||
lax_fun = lsp_stats.vonmises.pdf
|
||||
|
||||
def args_maker():
|
||||
x, kappa = map(rng, shapes, dtypes)
|
||||
kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype)
|
||||
return [x, kappa]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testVonMisesLogPdf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
scipy_fun = osp_stats.vonmises.pdf
|
||||
lax_fun = lsp_stats.vonmises.pdf
|
||||
|
||||
def args_maker():
|
||||
x, kappa = map(rng, shapes, dtypes)
|
||||
kappa = np.where(kappa < 0, kappa * -1, kappa).astype(kappa.dtype)
|
||||
return [x, kappa]
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testPoissonLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user