mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add beta function
This commit is contained in:
parent
1126945da8
commit
02f6fcb9da
@ -131,6 +131,7 @@ jax.scipy.special
|
||||
:toctree: _autosummary
|
||||
|
||||
bernoulli
|
||||
beta
|
||||
betainc
|
||||
betaln
|
||||
digamma
|
||||
|
@ -56,6 +56,12 @@ betaln = _wraps(
|
||||
)(_betaln_impl)
|
||||
|
||||
|
||||
@_wraps(osp_special.beta, module='scipy.special')
|
||||
def beta(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
x, y = promote_args_inexact("beta", x, y)
|
||||
return lax.exp(betaln(x, y))
|
||||
|
||||
|
||||
@_wraps(osp_special.betainc, module='scipy.special')
|
||||
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
|
||||
a, b, x = promote_args_inexact("betainc", a, b, x)
|
||||
|
@ -19,6 +19,7 @@ from jax._src.scipy.special import (
|
||||
bernoulli as bernoulli,
|
||||
betainc as betainc,
|
||||
betaln as betaln,
|
||||
beta as beta,
|
||||
bessel_jn as bessel_jn,
|
||||
digamma as digamma,
|
||||
entr as entr,
|
||||
|
@ -52,6 +52,9 @@ int_dtypes = jtu.dtypes.integer
|
||||
# don't expect numerical gradient tests to pass for inputs very close to 0.
|
||||
|
||||
JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record(
|
||||
"beta", 2, float_dtypes, jtu.rand_positive, False
|
||||
),
|
||||
op_record(
|
||||
"betaln", 2, float_dtypes, jtu.rand_positive, False
|
||||
),
|
||||
|
Loading…
x
Reference in New Issue
Block a user