2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2020 The JAX Authors.
|
2020-10-16 16:55:14 -04:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import scipy.stats as osp_stats
|
|
|
|
|
|
|
|
from jax import lax
|
2023-03-08 10:29:04 -08:00
|
|
|
import jax.numpy as jnp
|
2022-03-07 12:25:01 -08:00
|
|
|
from jax._src.lax.lax import _const as _lax_const
|
2024-01-24 14:14:19 -08:00
|
|
|
from jax._src.numpy.util import implements, promote_args_inexact
|
2020-10-16 16:55:14 -04:00
|
|
|
from jax.scipy.special import xlog1py
|
2022-10-12 13:42:11 -07:00
|
|
|
from jax._src.typing import Array, ArrayLike
|
|
|
|
|
2020-10-16 16:55:14 -04:00
|
|
|
|
2024-01-24 14:14:19 -08:00
|
|
|
@implements(osp_stats.geom.logpmf, update_doc=False)
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  # Log-PMF of the geometric distribution with success probability ``p``,
  # shifted by ``loc``:
  #   log(p) + (k - loc - 1) * log(1 - p)   for k - loc in {1, 2, 3, ...}
  #   -inf                                  otherwise
  # All arguments are promoted to a common inexact dtype and broadcast.
  k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
  zero = _lax_const(k, 0)
  one = _lax_const(k, 1)
  x = lax.sub(k, loc)
  # xlog1py(a, b) is defined as 0 when a == 0, so x == 1 with p == 1 gives
  # log(1) = 0 rather than the 0 * -inf = nan a naive product would produce.
  log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p)
  # The support is the positive integers. scipy returns -inf both for
  # x <= 0 and for non-integer x. The non-integer test is written as
  # floor(x) < x (not floor(x) != x) so that NaN inputs compare False and
  # continue to propagate as NaN instead of being masked to -inf.
  outside_support = jnp.logical_or(lax.le(x, zero), lax.lt(lax.floor(x), x))
  return jnp.where(outside_support, -jnp.inf, log_probs)
|
|
|
|
|
2022-10-12 13:42:11 -07:00
|
|
|
|
2024-01-24 14:14:19 -08:00
|
|
|
@implements(osp_stats.geom.pmf, update_doc=False)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  # Probability mass of the loc-shifted geometric distribution. It is
  # computed by exponentiating ``logpmf`` so that support handling
  # (points outside the support map to exp(-inf) == 0) lives in one place.
  log_mass = logpmf(k, p, loc=loc)
  return jnp.exp(log_mass)
|