Add beta function

This commit is contained in:
Ben West 2023-11-05 15:37:38 -08:00
parent 1126945da8
commit 02f6fcb9da
4 changed files with 11 additions and 0 deletions

View File

@ -131,6 +131,7 @@ jax.scipy.special
:toctree: _autosummary
bernoulli
beta
betainc
betaln
digamma

View File

@ -56,6 +56,12 @@ betaln = _wraps(
)(_betaln_impl)
@_wraps(osp_special.beta, module='scipy.special')
def beta(x: ArrayLike, y: ArrayLike) -> Array:
x, y = promote_args_inexact("beta", x, y)
return lax.exp(betaln(x, y))
@_wraps(osp_special.betainc, module='scipy.special')
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
a, b, x = promote_args_inexact("betainc", a, b, x)

View File

@ -19,6 +19,7 @@ from jax._src.scipy.special import (
bernoulli as bernoulli,
betainc as betainc,
betaln as betaln,
beta as beta,
bessel_jn as bessel_jn,
digamma as digamma,
entr as entr,

View File

@ -52,6 +52,9 @@ int_dtypes = jtu.dtypes.integer
# don't expect numerical gradient tests to pass for inputs very close to 0.
JAX_SPECIAL_FUNCTION_RECORDS = [
op_record(
"beta", 2, float_dtypes, jtu.rand_positive, False
),
op_record(
"betaln", 2, float_dtypes, jtu.rand_positive, False
),