2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2020-10-16 16:55:14 -04:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import scipy.stats as osp_stats
|
|
|
|
|
|
|
|
from jax import lax
|
2022-03-07 12:25:01 -08:00
|
|
|
from jax._src.lax.lax import _const as _lax_const
|
2020-10-16 18:08:20 -04:00
|
|
|
from jax._src.numpy.util import _wraps
|
2022-03-02 09:13:58 -08:00
|
|
|
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or
|
2022-10-12 13:42:11 -07:00
|
|
|
from jax._src.typing import Array, ArrayLike
|
2023-03-07 18:30:45 -08:00
|
|
|
from jax.scipy.special import betaln, xlogy, xlog1py
|
2020-10-16 16:55:14 -04:00
|
|
|
|
|
|
|
|
|
|
|
@_wraps(osp_stats.beta.logpdf, update_doc=False)
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
           loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Log-density of the location/scale-shifted beta distribution, following
  # scipy.stats.beta.logpdf. The public docstring is produced by `_wraps`
  # from SciPy's docstring, so the notes live in comments here.
  #
  #   x:      points at which the log-density is evaluated.
  #   a, b:   shape parameters; valid only when strictly positive.
  #   loc:    left end of the support.
  #   scale:  width of the support [loc, loc + scale]; valid only when > 0.
  #
  # Returns -inf for x outside [loc, loc + scale], and NaN wherever a, b or
  # scale is not strictly positive (SciPy's convention for invalid params).
  x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
  one = _lax_const(x, 1)
  zero = _lax_const(a, 0)
  nan = _lax_const(x, float("nan"))
  shape_term = lax.neg(betaln(a, b))
  # Map x onto the standard unit interval.
  y = lax.div(lax.sub(x, loc), scale)
  # (a - 1) * log(y) + (b - 1) * log(1 - y). xlogy / xlog1py yield 0 when the
  # coefficient is 0, so a == 1 or b == 1 stays finite at the support edges.
  log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                            xlog1py(lax.sub(b, one), lax.neg(y)))
  # Subtracting log(scale) accounts for the change of variables x -> y.
  log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
  outside_support = logical_or(lax.gt(x, lax.add(loc, scale)),
                               lax.lt(x, loc))
  log_probs = where(outside_support, -inf, log_probs)
  # Invalid parameters override the support mask, matching SciPy, which
  # reports NaN for them regardless of x.
  invalid_params = logical_or(logical_or(lax.le(a, zero), lax.le(b, zero)),
                              lax.le(scale, zero))
  return where(invalid_params, nan, log_probs)
|
|
|
|
|
2022-10-12 13:42:11 -07:00
|
|
|
|
2020-10-16 16:55:14 -04:00
|
|
|
@_wraps(osp_stats.beta.pdf, update_doc=False)
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
        loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Density of the location/scale-shifted beta distribution. It is obtained by
  # exponentiating this module's `logpdf` rather than evaluating the density
  # directly, so any point `logpdf` maps to -inf yields a density of 0.
  # The public docstring is produced by `_wraps` from SciPy's beta.pdf.
  log_density = logpdf(x, a, b, loc, scale)
  return lax.exp(log_density)
|