2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2021 The JAX Authors.
|
2021-02-03 15:09:21 -05:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
import scipy.stats as osp_stats
|
|
|
|
|
|
|
|
from jax import lax
|
2022-03-07 12:25:01 -08:00
|
|
|
from jax._src.lax.lax import _const as _lax_const
|
2021-02-03 15:09:21 -05:00
|
|
|
from jax._src.numpy.util import _wraps
|
2022-03-02 09:13:58 -08:00
|
|
|
from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf
|
2022-10-12 13:42:11 -07:00
|
|
|
from jax._src.typing import Array, ArrayLike
|
2021-02-03 15:09:21 -05:00
|
|
|
|
|
|
|
|
|
|
|
@_wraps(osp_stats.chi2.logpdf, update_doc=False)
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Log-density of the chi-square distribution with `df` degrees of freedom,
  # shifted by `loc` and scaled by `scale`. With y = (x - loc) / scale:
  #   log f(x) = (df/2 - 1) * log(y) - y/2
  #              - lgamma(df/2) - (df/2) * log(2) - log(scale)
  # and -inf for x < loc (outside the support).
  # xlogy(a, b) evaluates a * log(b) but returns 0 wherever a == 0.
  from jax._src.scipy.special import xlogy

  x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
  one = _lax_const(x, 1)
  two = _lax_const(x, 2)
  y = lax.div(lax.sub(x, loc), scale)
  df_on_two = lax.div(df, two)

  # xlogy instead of mul(coeff, log(y)): at y == 0 with df == 2 the
  # coefficient (df/2 - 1) is exactly 0, and 0 * log(0) = 0 * -inf would be
  # NaN. The correct limit is 0, giving logpdf(loc, 2) = -log(2) - log(scale).
  kernel = lax.sub(xlogy(lax.sub(df_on_two, one), y), lax.div(y, two))

  # Normalizing constant: -(lgamma(df/2) + (df/2) * log(2)).
  nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),
                              lax.div(lax.mul(lax.log(two), df), two)))

  log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
  # For x < loc (and scale > 0) y is negative, so log(y) above is NaN; mask
  # those points to -inf since they lie outside the support.
  return where(lax.lt(x, loc), -inf, log_probs)
|
2021-02-03 15:09:21 -05:00
|
|
|
|
|
|
|
@_wraps(osp_stats.chi2.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  # Chi-square probability density: computed in log space by this module's
  # `logpdf` (same df / loc / scale arguments) and then exponentiated.
  log_density = logpdf(x, df, loc, scale)
  return lax.exp(log_density)
|