Merge pull request #18499 from renecotyfanboy:hyp1f1_poch

PiperOrigin-RevId: 582765493
This commit is contained in:
jax authors 2023-11-15 12:25:59 -08:00
commit 840b5c5d6d
4 changed files with 186 additions and 1 deletions

View File

@ -147,6 +147,7 @@ jax.scipy.special
gammainc
gammaincc
gammaln
hyp1f1
i0
i0e
i1
@ -159,6 +160,7 @@ jax.scipy.special
multigammaln
ndtr
ndtri
poch
polygamma
spence
sph_harm

View File

@ -1719,3 +1719,183 @@ def bernoulli(n: int) -> Array:
k = jnp.arange(2, 50, dtype=bn.dtype) # Choose 50 because 2 ** -50 < 1E-15
q2 = jnp.sum(k[:, None] ** -m[None, :], axis=0)
return bn.at[4::2].set(q1 * (1 + q2))
@custom_derivatives.custom_jvp
@_wraps(osp_special.poch, module='scipy.special', lax_description="""\
The JAX version only accepts positive and real inputs.""")
def poch(z: ArrayLike, m: ArrayLike) -> Array:
# Factorial definition when m is close to an integer, otherwise gamma definition.
z, m = promote_args_inexact("poch", z, m)
return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z))
def _poch_z_derivative(z, m):
"""
Defined in :
https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/01/
"""
return (digamma(z + m) - digamma(z)) * poch(z, m)
def _poch_m_derivative(z, m):
"""
Defined in :
https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/02/
"""
return digamma(z + m) * poch(z, m)
poch.defjvps(
lambda z_dot, primal_out, z, m: _poch_z_derivative(z, m) * z_dot,
lambda m_dot, primal_out, z, m: _poch_m_derivative(z, m) * m_dot,
)
def _hyp1f1_serie(a, b, x):
"""
Compute the 1F1 hypergeometric function using the taylor expansion
See Eq. 3.2 and associated method (a) from PEARSON, OLVER & PORTER 2014
https://doi.org/10.48550/arXiv.1407.7786
"""
def body(state):
serie, k, term = state
serie += term
term *= (a + k) / (b + k) * x / (k + 1)
k += 1
return serie, k, term
def cond(state):
serie, k, term = state
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
init = 1, 1, a / b * x
return lax.while_loop(cond, body, init)[0]
def _hyp1f1_asymptotic(a, b, x):
"""
Compute the 1F1 hypergeometric function using asymptotic expansion
See Eq. 3.8 and simplification for real inputs from PEARSON, OLVER & PORTER 2014
https://doi.org/10.48550/arXiv.1407.7786
"""
def body(state):
serie, k, term = state
serie += term
term *= (b - a + k) * (1 - a + k) / (k + 1) / x
k += 1
return serie, k, term
def cond(state):
serie, k, term = state
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
init = 1, 1, (b - a) * (1 - a) / x
serie = lax.while_loop(cond, body, init)[0]
return gamma(b) / gamma(a) * lax.exp(x) * x ** (a - b) * serie
@jit
@jnp.vectorize
def _hyp1f1_a_derivative(a, b, x):
"""
Define it as a serie using :
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/
"""
def body(state):
serie, k, term = state
serie += term * (digamma(a + k) - digamma(a))
term *= (a + k) / (b + k) * x / (k + 1)
k += 1
return serie, k, term
def cond(state):
serie, k, term = state
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
init = 0, 1, a / b * x
return lax.while_loop(cond, body, init)[0]
@jit
@jnp.vectorize
def _hyp1f1_b_derivative(a, b, x):
"""
Define it as a serie using :
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/
"""
def body(state):
serie, k, term = state
serie += term * (digamma(b) - digamma(b + k))
term *= (a + k) / (b + k) * x / (k + 1)
k += 1
return serie, k, term
def cond(state):
serie, k, term = state
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
init = 0, 1, a / b * x
return lax.while_loop(cond, body, init)[0]
@jit
def _hyp1f1_x_derivative(a, b, x):
"""
Define it as a serie using :
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/04/
"""
return a / b * hyp1f1(a + 1, b + 1, x)
@custom_derivatives.custom_jvp
@jit
@jnp.vectorize
@_wraps(osp_special.hyp1f1, module='scipy.special', lax_description="""\
The JAX version only accepts positive and real inputs. Values of a, b and x
leading to high values of 1F1 might be erroneous, considering enabling double
precision. Convention for a = b = 0 is 1, unlike in scipy's implementation.""")
def hyp1f1(a, b, x):
"""
Implementation of the 1F1 hypergeometric function for real valued inputs
Backed by https://doi.org/10.48550/arXiv.1407.7786
There is room for improvement in the implementation using recursion to
evaluate lower values of hyp1f1 when a or b or both are > 60-80
"""
a, b, x = promote_args_inexact('hyp1f1', a, b, x)
result = lax.cond(lax.abs(x) < 100, _hyp1f1_serie, _hyp1f1_asymptotic, a, b, x)
index = (a == 0) * 1 + ((a == b) & (a != 0)) * 2 + ((b == 0) & (a != 0)) * 3
return lax.select_n(index,
result,
jnp.array(1, dtype=x.dtype),
jnp.exp(x),
jnp.array(jnp.inf, dtype=x.dtype))
hyp1f1.defjvps(
lambda a_dot, primal_out, a, b, x: _hyp1f1_a_derivative(a, b, x) * a_dot,
lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
)

View File

@ -54,4 +54,6 @@ from jax._src.scipy.special import (
zeta as zeta,
kl_div as kl_div,
rel_entr as rel_entr,
poch as poch,
hyp1f1 as hyp1f1,
)

View File

@ -141,7 +141,8 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
op_record(
"rel_entr", 2, float_dtypes, jtu.rand_positive, True,
),
op_record("poch", 2, float_dtypes, jtu.rand_positive, True),
op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True)
]