Adds associated Legendre functions of the first kind.

Co-authored-by: Jake VanderPlas <jakevdp@google.com>
tlu7 2021-06-02 11:37:37 -07:00
parent 46cc654537
commit a02bf59233
4 changed files with 317 additions and 0 deletions


@@ -99,6 +99,7 @@ jax.scipy.special
log_ndtr
logit
logsumexp
lpmn
multigammaln
ndtr
ndtri


@@ -18,13 +18,17 @@ import numpy as np
import scipy.special as osp_special
from jax._src import api
from jax import jit
from jax import lax, core
from jax import ops
from jax.interpreters import ad
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
_promote_args_inexact)
from jax._src.numpy.util import _wraps
from typing import Tuple
@_wraps(osp_special.gammaln)
def gammaln(x):
@@ -667,3 +671,287 @@ def i1e(x):
def i1(x):
x, = _promote_args_inexact("i1", x)
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
def _gen_recurrence_mask(
l_max: int, is_normalized: bool = True
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Generates mask for recurrence relation on the remaining entries.
The remaining entries are with respect to the diagonal and offdiagonal
entries.
Args:
l_max: see `gen_normalized_legendre`.
is_normalized: True if the recurrence mask is used by normalized associated
Legendre functions.
Returns:
Arrays representing the mask used by the recurrence relations.
"""
# Computes all coefficients.
m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1]
if is_normalized:
c0 = l_mat * l_mat
c1 = m_mat * m_mat
c2 = 2.0 * l_mat
c3 = (l_mat - 1.0) * (l_mat - 1.0)
d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
else:
d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
d0_mask_indices = jnp.triu_indices(l_max + 1, 1)
d1_mask_indices = jnp.triu_indices(l_max + 1, 2)
d_zeros = jnp.zeros((l_max + 1, l_max + 1))
d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices])
d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices])
# Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
# i = jnp.arange(l_max + 1)[:, None, None]
# j = jnp.arange(l_max + 1)[None, :, None]
# k = jnp.arange(l_max + 1)[None, None, :]
i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1]
mask = 1.0 * (i + j - k == 0)
d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask)
d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask)
return (d0_mask_3d, d1_mask_3d)
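For orientation, both returned masks have shape (l_max + 1, l_max + 1, l_max + 1); slice i of each holds the coefficients consumed at iteration i of the fori_loop in _gen_associated_legendre below, i.e. the coefficients for the entries whose degree exceeds the order by i. A minimal sketch of inspecting them (the private import path is an assumption for illustration only and may change):

from jax._src.scipy import special as jsp_special  # private module, illustration only

d0_mask_3d, d1_mask_3d = jsp_special._gen_recurrence_mask(2, is_normalized=False)
print(d0_mask_3d.shape, d1_mask_3d.shape)  # (3, 3, 3) (3, 3, 3)
print(d0_mask_3d[2])  # nonzero only where the degree exceeds the order by 2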
@partial(jit, static_argnums=(2,))
def _gen_derivatives(p: jnp.ndarray,
x: jnp.ndarray,
is_normalized: bool) -> jnp.ndarray:
"""Generates derivatives of associated Legendre functions of the first kind.
Args:
p: The 3D array containing the values of associated Legendre functions; the
dimensions are in the sequence of order (m), degree (l), and evaluation
points.
x: A vector of type `float32` or `float64` containing the sampled points.
is_normalized: True if the associated Legendre functions are normalized.
Returns:
The 3D array representing the derivatives of associated Legendre functions
of the first kind.
"""
num_m, num_l, num_x = p.shape
# p_{l-1}^m.
p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]
# p_{l-1}^{m+2}.
p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]
# p_{l-1}^{m-2}.
p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]
# Derivative computation requires negative orders.
if is_normalized:
raise NotImplementedError(
'Negative orders for normalized functions are not implemented yet.')
else:
if num_l > 1:
l_vec = jnp.arange(1, num_l - 1)
p_p1 = p[1, 1:num_l - 1, :]
coeff = -1.0 / ((l_vec + 1) * l_vec)
update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1)
if num_l > 2:
l_vec = jnp.arange(2, num_l - 1)
p_p2 = p[2, 2:num_l - 1, :]
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2)
m_mat, l_mat = jnp.mgrid[:num_m, :num_l]
coeff_zeros = jnp.zeros((num_m, num_l))
upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
zero_vec = jnp.zeros((num_l,))
a0 = -0.5 / (m_mat - 1.0)
a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
a0_masked = a0_masked.at[1, :].set(zero_vec)
b0 = l_mat + m_mat
c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
c0_masked = c0_masked.at[1, :].set(zero_vec)
# p_l^{m-1}.
p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))
d0 = -0.5 / (m_mat + 1.0)
d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
e0 = d0 * b0 * (b0 + 1.0)
e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])
# p_l^{m+1}.
p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))
f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l
# Special treatment of the singularity at m = 1.
if num_m > 1:
l_vec = jnp.arange(num_l)
g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
if num_l > 2:
g0 = g0 - p[2, :, :]
p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x,)))
return p_derivative
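The derivatives above follow the same convention as scipy.special.lpmn, i.e. dP_l^m/dz, which is also what the test at the bottom of this commit compares against. A minimal finite-difference sanity sketch of that convention on the scipy side (the evaluation point, step size, and tolerances are illustrative choices):

import numpy as np
import scipy.special as osp_special

x, h = 0.3, 1e-6  # illustrative evaluation point and step size
vals_plus, _ = osp_special.lpmn(2, 2, x + h)
vals_minus, _ = osp_special.lpmn(2, 2, x - h)
_, derivs = osp_special.lpmn(2, 2, x)
# The central difference of the values should approximate the returned derivatives.
np.testing.assert_allclose((vals_plus - vals_minus) / (2 * h), derivs,
                           rtol=1e-4, atol=1e-8)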
@partial(jit, static_argnums=(0, 2))
def _gen_associated_legendre(l_max: int,
x: jnp.ndarray,
is_normalized: bool) -> jnp.ndarray:
r"""Computes associated Legendre functions (ALFs) of the first kind.
The ALFs of the first kind are used in spherical harmonics. The spherical
harmonic of degree `l` and order `m` can be written as
`Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
normalization factor and θ and φ are the colatitude and longitude,
respectively. `N_l^m` is chosen such that the spherical harmonics form a
set of orthonormal basis functions of L^2(S^2). For the computational
efficiency of the spherical harmonics transform, the normalization factor is
applied in the computation of the ALFs. In addition, normalizing `P_l^m`
avoids overflow/underflow and achieves better numerical stability. Three
recurrence relations are used in the computation.
Args:
l_max: The maximum degree of the associated Legendre function. Both the
degrees and orders are `[0, 1, 2, ..., l_max]`.
x: A vector of type `float32`, `float64` containing the sampled points in
spherical coordinates, at which the ALFs are computed; `x` is essentially
`cos(θ)`. For the numerical integration used by the spherical harmonics
transforms, `x` contains the quadrature points in the interval of
`[-1, 1]`. There are several approaches to provide the quadrature points:
Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
method (`scipy.special.roots_chebyu`), and Driscoll & Healy
method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
transforms and convolutions on the 2-sphere." Advances in applied
mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
points are nearly equally spaced along θ and provide exact discrete
orthogonality, `(P^m)^T W P^m = I`, where `T` denotes the transpose
operation, `W` is a diagonal matrix containing the quadrature weights,
and `I` is the identity matrix. The Gauss-Chebyshev points are equally
spaced, but only provide approximate discrete orthogonality. The
Driscoll & Healy quadrature points are equally spaced and provide exact
discrete orthogonality. The Driscoll & Healy approach requires the number
of sampling points to be twice the number of frequency points (modes),
which enables the use of the FFT and yields a fast spherical harmonics
transform.
is_normalized: True if the associated Legendre functions are normalized.
With normalization, `N_l^m` is applied such that the spherical harmonics
form a set of orthonormal basis functions of L^2(S^2).
Returns:
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
of the ALFs at `x`; the dimensions are in the sequence of order, degree, and
evaluation points.
"""
p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))
a_idx = jnp.arange(1, l_max + 1)
b_idx = jnp.arange(l_max)
if is_normalized:
initial_value = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0).
f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
f_b = jnp.sqrt(2.0 * b_idx + 3.0)
else:
initial_value = 1.0 # The initial value p(0,0).
f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
f_b = 2.0 * b_idx + 1.0
p = p.at[(0, 0)].set(initial_value)
# Compute the diagonal entries p(l,l) with recurrence.
y = jnp.cumprod(
jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])),
axis=0)
p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
diag_indices = jnp.diag_indices(l_max + 1)
p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)
# Compute the off-diagonal entries with recurrence.
p_offdiag = jnp.einsum('ij,ij->ij',
jnp.einsum('i,j->ij', f_b, x),
p[jnp.diag_indices(l_max)])
offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
p = p.at[offdiag_indices].set(p_offdiag)
# Compute the remaining entries with recurrence.
d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
l_max, is_normalized=is_normalized)
def body_fun(i, p_val):
coeff_0 = d0_mask_3d[i]
coeff_1 = d1_mask_3d[i]
h = (jnp.einsum('ij,ijk->ijk',
coeff_0,
jnp.einsum(
'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1)))
p_val = p_val + h
return p_val
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
return p
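As a cross-check of the three recurrences described in the docstring above, here is a minimal pure-NumPy reference for the unnormalized case (a sketch; the name reference_lpmn_values is illustrative and not part of this commit). Stacking scipy.special.lpmn(l_max, l_max, z[i])[0] over the sample points along the last axis should agree with its output:

import numpy as np

def reference_lpmn_values(l_max, x):
  # Unnormalized ALFs via the standard recurrences (with Condon-Shortley phase):
  #   diagonal:      P_m^m         = -(2m - 1) * sqrt(1 - x^2) * P_{m-1}^{m-1}
  #   off-diagonal:  P_{m+1}^m     = (2m + 1) * x * P_m^m
  #   remaining:     (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m
  x = np.asarray(x, dtype=float)
  p = np.zeros((l_max + 1, l_max + 1, x.size))  # axes: order m, degree l, points
  p[0, 0] = 1.0
  for m in range(1, l_max + 1):
    p[m, m] = -(2 * m - 1) * np.sqrt(1.0 - x * x) * p[m - 1, m - 1]
  for m in range(l_max):
    p[m, m + 1] = (2 * m + 1) * x * p[m, m]
  for m in range(l_max + 1):
    for l in range(m + 2, l_max + 1):
      p[m, l] = ((2 * l - 1) * x * p[m, l - 1]
                 - (l + m - 1) * p[m, l - 2]) / (l - m)
  return p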
def lpmn(m, n, z):
"""The associated Legendre functions (ALFs) of the first kind.
Args:
m: The maximum order of the associated Legendre functions.
n: The maximum degree of the associated Legendre function, often called
`l` in describing ALFs. Both the degrees and orders are
`[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
z: A vector of type `float32` or `float64` containing the sampling
points at which the ALFs are computed.
Returns:
A 2-tuple of 3D arrays of shape `(l_max + 1, l_max + 1, len(z))` containing
the values and derivatives of the associated Legendre functions of the
first kind. The return type matches the type of `z`.
Raises:
TypeError if the dtype of `z` is not `float32` or `float64`.
ValueError if `z` is not a 1D array.
NotImplementedError if `m != n`.
"""
dtype = lax.dtype(z)
if dtype not in (jnp.float32, jnp.float64):
raise TypeError(
'z.dtype={} is not supported, see docstring for supported types.'
.format(dtype))
if z.ndim != 1:
raise ValueError('z must be a 1D array.')
m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
n = core.concrete_or_error(int, n, 'Argument n of lpmn.')
if m != n:
raise NotImplementedError('Computations for m!=n are not yet supported.')
l_max = n
is_normalized = False
p_vals = _gen_associated_legendre(l_max, z, is_normalized)
p_derivatives = _gen_derivatives(p_vals, z, is_normalized)
return (p_vals, p_derivatives)
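A hypothetical usage sketch of the new API against scipy (enabling x64 is an assumption so that the comparison runs in double precision; the sample points and tolerances are illustrative):

import numpy as np
import scipy.special as osp_special
from jax.config import config
config.update('jax_enable_x64', True)  # assumption: compare in double precision
from jax.scipy.special import lpmn

z = np.linspace(-0.5, 0.5, 5)
p_vals, p_derivs = lpmn(m=3, n=3, z=z)  # each of shape (4, 4, 5)
expected_vals = np.stack([osp_special.lpmn(3, 3, zi)[0] for zi in z], axis=-1)
np.testing.assert_allclose(np.asarray(p_vals), expected_vals, rtol=1e-6, atol=1e-5)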


@@ -32,6 +32,7 @@ from jax._src.scipy.special import (
i1e,
logit,
logsumexp,
lpmn,
multigammaln,
log_ndtr,
ndtr,


@@ -239,6 +239,33 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_maxdegree={}_inputsize={}".format(l_max, num_z),
"l_max": l_max,
"num_z": num_z}
for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
def testLpmn(self, l_max, num_z):
# Points at which the associated Legendre functions are evaluated.
z = np.linspace(-0.2, 0.9, num_z)
actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z)
# The expected results are obtained from scipy.
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))
for i in range(num_z):
val, derivative = osp_special.lpmn(l_max, l_max, z[i])
expected_p_vals[:, :, i] = val
expected_p_derivatives[:, :, i] = derivative
with self.subTest('Test values.'):
self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6)
with self.subTest('Test derivatives.'):
self.assertAllClose(actual_p_derivatives, expected_p_derivatives,
rtol=1e-6, atol=8.4e-4)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())