mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #18499 from renecotyfanboy:hyp1f1_poch
PiperOrigin-RevId: 582765493
This commit is contained in:
commit
840b5c5d6d
@ -147,6 +147,7 @@ jax.scipy.special
|
||||
gammainc
|
||||
gammaincc
|
||||
gammaln
|
||||
hyp1f1
|
||||
i0
|
||||
i0e
|
||||
i1
|
||||
@ -159,6 +160,7 @@ jax.scipy.special
|
||||
multigammaln
|
||||
ndtr
|
||||
ndtri
|
||||
poch
|
||||
polygamma
|
||||
spence
|
||||
sph_harm
|
||||
|
@ -1719,3 +1719,183 @@ def bernoulli(n: int) -> Array:
|
||||
k = jnp.arange(2, 50, dtype=bn.dtype) # Choose 50 because 2 ** -50 < 1E-15
|
||||
q2 = jnp.sum(k[:, None] ** -m[None, :], axis=0)
|
||||
return bn.at[4::2].set(q1 * (1 + q2))
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@_wraps(osp_special.poch, module='scipy.special', lax_description="""\
|
||||
The JAX version only accepts positive and real inputs.""")
|
||||
def poch(z: ArrayLike, m: ArrayLike) -> Array:
|
||||
# Factorial definition when m is close to an integer, otherwise gamma definition.
|
||||
z, m = promote_args_inexact("poch", z, m)
|
||||
|
||||
return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z))
|
||||
|
||||
|
||||
def _poch_z_derivative(z, m):
|
||||
"""
|
||||
Defined in :
|
||||
https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/01/
|
||||
"""
|
||||
|
||||
return (digamma(z + m) - digamma(z)) * poch(z, m)
|
||||
|
||||
|
||||
def _poch_m_derivative(z, m):
|
||||
"""
|
||||
Defined in :
|
||||
https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/02/
|
||||
"""
|
||||
|
||||
return digamma(z + m) * poch(z, m)
|
||||
|
||||
|
||||
poch.defjvps(
|
||||
lambda z_dot, primal_out, z, m: _poch_z_derivative(z, m) * z_dot,
|
||||
lambda m_dot, primal_out, z, m: _poch_m_derivative(z, m) * m_dot,
|
||||
)
|
||||
|
||||
|
||||
def _hyp1f1_serie(a, b, x):
|
||||
"""
|
||||
Compute the 1F1 hypergeometric function using the taylor expansion
|
||||
See Eq. 3.2 and associated method (a) from PEARSON, OLVER & PORTER 2014
|
||||
https://doi.org/10.48550/arXiv.1407.7786
|
||||
"""
|
||||
|
||||
def body(state):
|
||||
serie, k, term = state
|
||||
serie += term
|
||||
term *= (a + k) / (b + k) * x / (k + 1)
|
||||
k += 1
|
||||
|
||||
return serie, k, term
|
||||
|
||||
def cond(state):
|
||||
serie, k, term = state
|
||||
|
||||
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
|
||||
|
||||
init = 1, 1, a / b * x
|
||||
|
||||
return lax.while_loop(cond, body, init)[0]
|
||||
|
||||
|
||||
def _hyp1f1_asymptotic(a, b, x):
|
||||
"""
|
||||
Compute the 1F1 hypergeometric function using asymptotic expansion
|
||||
See Eq. 3.8 and simplification for real inputs from PEARSON, OLVER & PORTER 2014
|
||||
https://doi.org/10.48550/arXiv.1407.7786
|
||||
"""
|
||||
|
||||
def body(state):
|
||||
serie, k, term = state
|
||||
serie += term
|
||||
term *= (b - a + k) * (1 - a + k) / (k + 1) / x
|
||||
k += 1
|
||||
|
||||
return serie, k, term
|
||||
|
||||
def cond(state):
|
||||
serie, k, term = state
|
||||
|
||||
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-8)
|
||||
|
||||
init = 1, 1, (b - a) * (1 - a) / x
|
||||
serie = lax.while_loop(cond, body, init)[0]
|
||||
|
||||
return gamma(b) / gamma(a) * lax.exp(x) * x ** (a - b) * serie
|
||||
|
||||
|
||||
@jit
|
||||
@jnp.vectorize
|
||||
def _hyp1f1_a_derivative(a, b, x):
|
||||
"""
|
||||
Define it as a serie using :
|
||||
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/
|
||||
"""
|
||||
|
||||
def body(state):
|
||||
serie, k, term = state
|
||||
serie += term * (digamma(a + k) - digamma(a))
|
||||
term *= (a + k) / (b + k) * x / (k + 1)
|
||||
k += 1
|
||||
|
||||
return serie, k, term
|
||||
|
||||
def cond(state):
|
||||
serie, k, term = state
|
||||
|
||||
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
|
||||
|
||||
init = 0, 1, a / b * x
|
||||
|
||||
return lax.while_loop(cond, body, init)[0]
|
||||
|
||||
|
||||
@jit
|
||||
@jnp.vectorize
|
||||
def _hyp1f1_b_derivative(a, b, x):
|
||||
"""
|
||||
Define it as a serie using :
|
||||
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/
|
||||
"""
|
||||
|
||||
def body(state):
|
||||
serie, k, term = state
|
||||
serie += term * (digamma(b) - digamma(b + k))
|
||||
term *= (a + k) / (b + k) * x / (k + 1)
|
||||
k += 1
|
||||
|
||||
return serie, k, term
|
||||
|
||||
def cond(state):
|
||||
serie, k, term = state
|
||||
|
||||
return (k < 250) & (lax.abs(term) / lax.abs(serie) > 1e-15)
|
||||
|
||||
init = 0, 1, a / b * x
|
||||
|
||||
return lax.while_loop(cond, body, init)[0]
|
||||
|
||||
|
||||
@jit
|
||||
def _hyp1f1_x_derivative(a, b, x):
|
||||
"""
|
||||
Define it as a serie using :
|
||||
https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/04/
|
||||
"""
|
||||
|
||||
return a / b * hyp1f1(a + 1, b + 1, x)
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
@jit
|
||||
@jnp.vectorize
|
||||
@_wraps(osp_special.hyp1f1, module='scipy.special', lax_description="""\
|
||||
The JAX version only accepts positive and real inputs. Values of a, b and x
|
||||
leading to high values of 1F1 might be erroneous, considering enabling double
|
||||
precision. Convention for a = b = 0 is 1, unlike in scipy's implementation.""")
|
||||
def hyp1f1(a, b, x):
|
||||
"""
|
||||
Implementation of the 1F1 hypergeometric function for real valued inputs
|
||||
Backed by https://doi.org/10.48550/arXiv.1407.7786
|
||||
There is room for improvement in the implementation using recursion to
|
||||
evaluate lower values of hyp1f1 when a or b or both are > 60-80
|
||||
"""
|
||||
a, b, x = promote_args_inexact('hyp1f1', a, b, x)
|
||||
|
||||
result = lax.cond(lax.abs(x) < 100, _hyp1f1_serie, _hyp1f1_asymptotic, a, b, x)
|
||||
index = (a == 0) * 1 + ((a == b) & (a != 0)) * 2 + ((b == 0) & (a != 0)) * 3
|
||||
|
||||
return lax.select_n(index,
|
||||
result,
|
||||
jnp.array(1, dtype=x.dtype),
|
||||
jnp.exp(x),
|
||||
jnp.array(jnp.inf, dtype=x.dtype))
|
||||
|
||||
|
||||
hyp1f1.defjvps(
|
||||
lambda a_dot, primal_out, a, b, x: _hyp1f1_a_derivative(a, b, x) * a_dot,
|
||||
lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
|
||||
lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
|
||||
)
|
||||
|
@ -54,4 +54,6 @@ from jax._src.scipy.special import (
|
||||
zeta as zeta,
|
||||
kl_div as kl_div,
|
||||
rel_entr as rel_entr,
|
||||
poch as poch,
|
||||
hyp1f1 as hyp1f1,
|
||||
)
|
||||
|
@ -141,7 +141,8 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record(
|
||||
"rel_entr", 2, float_dtypes, jtu.rand_positive, True,
|
||||
),
|
||||
|
||||
op_record("poch", 2, float_dtypes, jtu.rand_positive, True),
|
||||
op_record("hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True)
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user