mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Adds associated Legendre functions of the first kind.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This commit is contained in:
parent
46cc654537
commit
a02bf59233
@ -99,6 +99,7 @@ jax.scipy.special
|
||||
log_ndtr
|
||||
logit
|
||||
logsumexp
|
||||
lpmn
|
||||
multigammaln
|
||||
ndtr
|
||||
ndtri
|
||||
|
@ -18,13 +18,17 @@ import numpy as np
|
||||
import scipy.special as osp_special
|
||||
|
||||
from jax._src import api
|
||||
from jax import jit
|
||||
from jax import lax, core
|
||||
from jax import ops
|
||||
from jax.interpreters import ad
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
|
||||
_promote_args_inexact)
|
||||
from jax._src.numpy.util import _wraps
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
@_wraps(osp_special.gammaln)
|
||||
def gammaln(x):
|
||||
@ -667,3 +671,287 @@ def i1e(x):
|
||||
def i1(x):
|
||||
x, = _promote_args_inexact("i1", x)
|
||||
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
|
||||
|
||||
|
||||
def _gen_recurrence_mask(
|
||||
l_max: int, is_normalized: bool = True
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
"""Generates mask for recurrence relation on the remaining entries.
|
||||
|
||||
The remaining entries are with respect to the diagonal and offdiagonal
|
||||
entries.
|
||||
|
||||
Args:
|
||||
l_max: see `gen_normalized_legendre`.
|
||||
is_normalized: True if the recurrence mask is used by normalized associated
|
||||
Legendre functions.
|
||||
|
||||
Returns:
|
||||
Arrays representing the mask used by the recurrence relations.
|
||||
"""
|
||||
|
||||
# Computes all coefficients.
|
||||
m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1]
|
||||
if is_normalized:
|
||||
c0 = l_mat * l_mat
|
||||
c1 = m_mat * m_mat
|
||||
c2 = 2.0 * l_mat
|
||||
c3 = (l_mat - 1.0) * (l_mat - 1.0)
|
||||
d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
|
||||
d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
|
||||
else:
|
||||
d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
|
||||
d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
|
||||
|
||||
d0_mask_indices = jnp.triu_indices(l_max + 1, 1)
|
||||
d1_mask_indices = jnp.triu_indices(l_max + 1, 2)
|
||||
d_zeros = jnp.zeros((l_max + 1, l_max + 1))
|
||||
d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices])
|
||||
d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices])
|
||||
|
||||
# Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
|
||||
# i = jnp.arange(l_max + 1)[:, None, None]
|
||||
# j = jnp.arange(l_max + 1)[None, :, None]
|
||||
# k = jnp.arange(l_max + 1)[None, None, :]
|
||||
i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1]
|
||||
mask = 1.0 * (i + j - k == 0)
|
||||
|
||||
d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask)
|
||||
d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask)
|
||||
|
||||
return (d0_mask_3d, d1_mask_3d)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(2))
|
||||
def _gen_derivatives(p: jnp.ndarray,
|
||||
x: jnp.ndarray,
|
||||
is_normalized: bool) -> jnp.ndarray:
|
||||
"""Generates derivatives of associated Legendre functions of the first kind.
|
||||
|
||||
Args:
|
||||
p: The 3D array containing the values of associated Legendre functions; the
|
||||
dimensions are in the sequence of order (m), degree (l), and evalution
|
||||
points.
|
||||
x: A vector of type `float32` or `float64` containing the sampled points.
|
||||
is_normalized: True if the associated Legendre functions are normalized.
|
||||
Returns:
|
||||
The 3D array representing the derivatives of associated Legendre functions
|
||||
of the first kind.
|
||||
"""
|
||||
|
||||
num_m, num_l, num_x = p.shape
|
||||
|
||||
# p_{l-1}^m.
|
||||
p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]
|
||||
|
||||
# p_{l-1}^{m+2}.
|
||||
p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]
|
||||
|
||||
# p_{l-1}^{m-2}.
|
||||
p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]
|
||||
|
||||
# Derivative computation requires negative orders.
|
||||
if is_normalized:
|
||||
raise NotImplementedError(
|
||||
'Negative orders for normalization is not implemented yet.')
|
||||
else:
|
||||
if num_l > 1:
|
||||
l_vec = jnp.arange(1, num_l - 1)
|
||||
p_p1 = p[1, 1:num_l - 1, :]
|
||||
coeff = -1.0 / ((l_vec + 1) * l_vec)
|
||||
update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
|
||||
p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1)
|
||||
|
||||
if num_l > 2:
|
||||
l_vec = jnp.arange(2, num_l - 1)
|
||||
p_p2 = p[2, 2:num_l - 1, :]
|
||||
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
|
||||
update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
|
||||
p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2)
|
||||
|
||||
m_mat, l_mat = jnp.mgrid[:num_m, :num_l]
|
||||
|
||||
coeff_zeros = jnp.zeros((num_m, num_l))
|
||||
upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
|
||||
zero_vec = jnp.zeros((num_l,))
|
||||
|
||||
a0 = -0.5 / (m_mat - 1.0)
|
||||
a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
|
||||
a0_masked = a0_masked.at[1, :].set(zero_vec)
|
||||
|
||||
b0 = l_mat + m_mat
|
||||
c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
|
||||
c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
|
||||
c0_masked = c0_masked.at[1, :].set(zero_vec)
|
||||
|
||||
# p_l^{m-1}.
|
||||
p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
|
||||
jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))
|
||||
|
||||
d0 = -0.5 / (m_mat + 1.0)
|
||||
d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
|
||||
e0 = d0 * b0 * (b0 + 1.0)
|
||||
e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])
|
||||
|
||||
# p_l^{m+1}.
|
||||
p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
|
||||
jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))
|
||||
|
||||
f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
|
||||
f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
|
||||
p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l
|
||||
|
||||
# Special treatment of the singularity at m = 1.
|
||||
if num_m > 1:
|
||||
l_vec = jnp.arange(num_l)
|
||||
g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
|
||||
if num_l > 2:
|
||||
g0 = g0 - p[2, :, :]
|
||||
p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
|
||||
p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
|
||||
p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x,)))
|
||||
|
||||
return p_derivative
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(0, 2))
|
||||
def _gen_associated_legendre(l_max: int,
|
||||
x: jnp.ndarray,
|
||||
is_normalized: bool) -> jnp.ndarray:
|
||||
r"""Computes associated Legendre functions (ALFs) of the first kind.
|
||||
|
||||
The ALFs of the first kind are used in spherical harmonics. The spherical
|
||||
harmonic of degree `l` and order `m` can be written as
|
||||
`Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
|
||||
normalization factor and θ and φ are the colatitude and longitude,
|
||||
repectively. `N_l^m` is chosen in the way that the spherical harmonics form
|
||||
a set of orthonormal basis function of L^2(S^2). For the computational
|
||||
efficiency of spherical harmonics transform, the normalization factor is
|
||||
used in the computation of the ALFs. In addition, normalizing `P_l^m`
|
||||
avoids overflow/underflow and achieves better numerical stability. Three
|
||||
recurrence relations are used in the computation.
|
||||
|
||||
Args:
|
||||
l_max: The maximum degree of the associated Legendre function. Both the
|
||||
degrees and orders are `[0, 1, 2, ..., l_max]`.
|
||||
x: A vector of type `float32`, `float64` containing the sampled points in
|
||||
spherical coordinates, at which the ALFs are computed; `x` is essentially
|
||||
`cos(θ)`. For the numerical integration used by the spherical harmonics
|
||||
transforms, `x` contains the quadrature points in the interval of
|
||||
`[-1, 1]`. There are several approaches to provide the quadrature points:
|
||||
Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
|
||||
method (`scipy.special.roots_chebyu`), and Driscoll & Healy
|
||||
method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
|
||||
transforms and convolutions on the 2-sphere." Advances in applied
|
||||
mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
|
||||
points are nearly equal-spaced along θ and provide exact discrete
|
||||
orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
|
||||
operation, `W` is a diagonal matrix containing the quadrature weights,
|
||||
and `I` is the identity matrix. The Gauss-Chebyshev points are equally
|
||||
spaced, which only provide approximate discrete orthogonality. The
|
||||
Driscoll & Healy qudarture points are equally spaced and provide the
|
||||
exact discrete orthogonality. The number of sampling points is required to
|
||||
be twice as the number of frequency points (modes) in the Driscoll & Healy
|
||||
approach, which enables FFT and achieves a fast spherical harmonics
|
||||
transform.
|
||||
is_normalized: True if the associated Legendre functions are normalized.
|
||||
With normalization, `N_l^m` is applied such that the spherical harmonics
|
||||
form a set of orthonormal basis functions of L^2(S^2).
|
||||
|
||||
Returns:
|
||||
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
|
||||
of the ALFs at `x`; the dimensions in the sequence of order, degree, and
|
||||
evalution points.
|
||||
"""
|
||||
p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))
|
||||
|
||||
a_idx = jnp.arange(1, l_max + 1)
|
||||
b_idx = jnp.arange(l_max)
|
||||
if is_normalized:
|
||||
initial_value = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0).
|
||||
f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
|
||||
f_b = jnp.sqrt(2.0 * b_idx + 3.0)
|
||||
else:
|
||||
initial_value = 1.0 # The initial value p(0,0).
|
||||
f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
|
||||
f_b = 2.0 * b_idx + 1.0
|
||||
|
||||
p = p.at[(0, 0)].set(initial_value)
|
||||
|
||||
# Compute the diagonal entries p(l,l) with recurrence.
|
||||
y = jnp.cumprod(
|
||||
jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])),
|
||||
axis=0)
|
||||
p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
|
||||
diag_indices = jnp.diag_indices(l_max + 1)
|
||||
p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)
|
||||
|
||||
# Compute the off-diagonal entries with recurrence.
|
||||
p_offdiag = jnp.einsum('ij,ij->ij',
|
||||
jnp.einsum('i,j->ij', f_b, x),
|
||||
p[jnp.diag_indices(l_max)])
|
||||
offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
|
||||
p = p.at[offdiag_indices].set(p_offdiag)
|
||||
|
||||
# Compute the remaining entries with recurrence.
|
||||
d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
|
||||
l_max, is_normalized=is_normalized)
|
||||
|
||||
def body_fun(i, p_val):
|
||||
coeff_0 = d0_mask_3d[i]
|
||||
coeff_1 = d1_mask_3d[i]
|
||||
h = (jnp.einsum('ij,ijk->ijk',
|
||||
coeff_0,
|
||||
jnp.einsum(
|
||||
'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
|
||||
jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1)))
|
||||
p_val = p_val + h
|
||||
return p_val
|
||||
|
||||
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def lpmn(m, n, z):
|
||||
"""The associated Legendre functions (ALFs) of the first kind.
|
||||
|
||||
Args:
|
||||
m: The maximum order of the associated Legendre functions.
|
||||
n: The maximum degree of the associated Legendre function, often called
|
||||
`l` in describing ALFs. Both the degrees and orders are
|
||||
`[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
|
||||
z: A vector of type `float32` or `float64` containing the sampling
|
||||
points at which the ALFs are computed.
|
||||
|
||||
Returns:
|
||||
A 2-tuple of 3D arrays of shape `(l_max + 1, l_max + 1, len(z))` containing
|
||||
the values and derivatives of the associated Legendre functions of the
|
||||
first kind. The return type matches the type of `z`.
|
||||
|
||||
Raises:
|
||||
TypeError if elements of array `z` are not in (float32, float64).
|
||||
ValueError if array `z` is not 1D.
|
||||
NotImplementedError if `m!=n`.
|
||||
"""
|
||||
dtype = lax.dtype(z)
|
||||
if dtype not in (jnp.float32, jnp.float64):
|
||||
raise TypeError(
|
||||
'z.dtype={} is not supported, see docstring for supported types.'
|
||||
.format(dtype))
|
||||
|
||||
if z.ndim != 1:
|
||||
raise ValueError('z must be a 1D array.')
|
||||
|
||||
m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
|
||||
n = core.concrete_or_error(int, n, 'Argument n of lpmn.')
|
||||
|
||||
if m != n:
|
||||
raise NotImplementedError('Computations for m!=n are not yet supported.')
|
||||
|
||||
l_max = n
|
||||
is_normalized = False
|
||||
p_vals = _gen_associated_legendre(l_max, z, is_normalized)
|
||||
p_derivatives = _gen_derivatives(p_vals, z, is_normalized)
|
||||
|
||||
return (p_vals, p_derivatives)
|
||||
|
@ -32,6 +32,7 @@ from jax._src.scipy.special import (
|
||||
i1e,
|
||||
logit,
|
||||
logsumexp,
|
||||
lpmn,
|
||||
multigammaln,
|
||||
log_ndtr,
|
||||
ndtr,
|
||||
|
@ -239,6 +239,33 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
|
||||
self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_maxdegree={}_inputsize={}".format(l_max, num_z),
|
||||
"l_max": l_max,
|
||||
"num_z": num_z}
|
||||
for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
|
||||
def testLpmn(self, l_max, num_z):
|
||||
# Points on which the associated Legendre functions areevaluated.
|
||||
z = np.linspace(-0.2, 0.9, num_z)
|
||||
actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z)
|
||||
|
||||
# The expected results are obtained from scipy.
|
||||
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
|
||||
expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))
|
||||
|
||||
for i in range(num_z):
|
||||
val, derivative = osp_special.lpmn(l_max, l_max, z[i])
|
||||
expected_p_vals[:, :, i] = val
|
||||
expected_p_derivatives[:, :, i] = derivative
|
||||
|
||||
with self.subTest('Test values.'):
|
||||
self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6)
|
||||
|
||||
with self.subTest('Test derivatives.'):
|
||||
self.assertAllClose(actual_p_derivatives,expected_p_derivatives,
|
||||
rtol=1e-6, atol=8.4e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user