mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Add jax.scipy.stats.binom
This commit is contained in:
parent
6bc74d2a98
commit
30d1a8a80f
42
jax/_src/scipy/stats/binom.py
Normal file
42
jax/_src/scipy/stats/binom.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
|
||||
import scipy.stats as osp_stats
|
||||
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.logpmf, update_doc=False)
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.binom.logpmf."""
|
||||
k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
|
||||
y = lax.sub(k, loc)
|
||||
comb_term = lax.sub(
|
||||
gammaln(n + 1),
|
||||
lax.add(gammaln(y + 1), gammaln(n - y + 1))
|
||||
)
|
||||
log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
|
||||
log_probs = lax.add(comb_term, log_linear_term)
|
||||
return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs)
|
||||
|
||||
|
||||
@_wraps(osp_stats.nbinom.pmf, update_doc=False)
|
||||
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
|
||||
"""JAX implementation of scipy.stats.binom.pmf."""
|
||||
return lax.exp(logpmf(k, n, p, loc))
|
@ -17,6 +17,7 @@
|
||||
|
||||
from jax.scipy.stats import bernoulli as bernoulli
|
||||
from jax.scipy.stats import beta as beta
|
||||
from jax.scipy.stats import binom as binom
|
||||
from jax.scipy.stats import cauchy as cauchy
|
||||
from jax.scipy.stats import dirichlet as dirichlet
|
||||
from jax.scipy.stats import expon as expon
|
||||
|
18
jax/scipy/stats/binom.py
Normal file
18
jax/scipy/stats/binom.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.scipy.stats.binom import (
|
||||
logpmf as logpmf,
|
||||
pmf as pmf,
|
||||
)
|
@ -971,6 +971,27 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testBinomLogPmf(self, shapes, dtypes):
|
||||
rng = jtu.rand_positive(self.rng())
|
||||
scipy_fun = osp_stats.binom.logpmf
|
||||
lax_fun = lsp_stats.binom.logpmf
|
||||
|
||||
def args_maker():
|
||||
k, n, logit, loc = map(rng, shapes, dtypes)
|
||||
k = np.floor(np.abs(k))
|
||||
n = np.ceil(np.abs(n))
|
||||
p = expit(logit)
|
||||
loc = np.floor(loc)
|
||||
return [k, n, p, loc]
|
||||
|
||||
tol = {np.float32: 1e-6, np.float64: 1e-8}
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
def testIssue972(self):
|
||||
self.assertAllClose(
|
||||
np.ones((4,), np.float32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user