Support value computation of associated Legendre functions.

Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This commit is contained in:
tlu7 2021-06-14 14:51:37 -07:00
parent 1e4d28a2d9
commit 095e6507b9
4 changed files with 103 additions and 1 deletions

View File

@ -100,6 +100,7 @@ jax.scipy.special
logit
logsumexp
lpmn
lpmn_values
multigammaln
ndtr
ndtri

View File

@ -914,7 +914,7 @@ def _gen_associated_legendre(l_max: int,
return p
def lpmn(m, n, z):
def lpmn(m: int, n: int, z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""The associated Legendre functions (ALFs) of the first kind.
Args:
@ -956,3 +956,57 @@ def lpmn(m, n, z):
p_derivatives = _gen_derivatives(p_vals, z, is_normalized)
return (p_vals, p_derivatives)
def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndarray:
r"""The associated Legendre functions (ALFs) of the first kind.
Unlike `lpmn`, this function only computes the values of ALFs.
The ALFs of the first kind can be used in spherical harmonics. The
spherical harmonic of degree `l` and order `m` can be written as
:math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`,
where :math:`N_l^m` is the normalization factor and θ and φ are the
colatitude and longitude, repectively. :math:`N_l^m` is chosen in the
way that the spherical harmonics form a set of orthonormal basis function
of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow
and achieves better numerical stability.
Args:
m: The maximum order of the associated Legendre functions.
n: The maximum degree of the associated Legendre function, often called
`l` in describing ALFs. Both the degrees and orders are
`[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
z: A vector of type `float32` or `float64` containing the sampling
points at which the ALFs are computed.
is_normalized: True if the associated Legendre functions are normalized.
With normalization, :math:`N_l^m` is applied such that the spherical
harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`.
Returns:
A 3D array of shape `(l_max + 1, l_max + 1, len(z))` containing
the values of the associated Legendre functions of the first kind. The
return type matches the type of `z`.
Raises:
TypeError if elements of array `z` are not in (float32, float64).
ValueError if array `z` is not 1D.
NotImplementedError if `m!=n`.
"""
dtype = lax.dtype(z)
if dtype not in (jnp.float32, jnp.float64):
raise TypeError(
'z.dtype={} is not supported, see docstring for supported types.'
.format(dtype))
if z.ndim != 1:
raise ValueError('z must be a 1D array.')
m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
n = core.concrete_or_error(int, n, 'Argument n of lpmn.')
if m != n:
raise NotImplementedError('Computations for m!=n are not yet supported.')
l_max = n
return _gen_associated_legendre(l_max, z, is_normalized)

View File

@ -33,6 +33,7 @@ from jax._src.scipy.special import (
logit,
logsumexp,
lpmn,
lpmn_values,
multigammaln,
log_ndtr,
ndtr,

View File

@ -266,6 +266,52 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual_p_derivatives,expected_p_derivatives,
rtol=1e-6, atol=8.4e-4)
with self.subTest('Test JIT compatibility'):
args_maker = lambda: [z]
lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
self._CompileAndCheck(lsp_special_fn, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_maxdegree={}_inputsize={}".format(l_max, num_z),
"l_max": l_max,
"num_z": num_z}
for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
def testNormalizedLpmnValues(self, l_max, num_z):
# Points on which the associated Legendre functions areevaluated.
z = np.linspace(-0.2, 0.9, num_z)
is_normalized = True
actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
# The expected results are obtained from scipy.
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
for i in range(num_z):
expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]
def apply_normalization(a):
"""Applies normalization to the associated Legendre functions."""
num_m, num_l, _ = a.shape
a_normalized = np.zeros_like(a)
for m in range(num_m):
for l in range(num_l):
c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
c2 = np.sqrt(c0 / c1)
a_normalized[m, l] = c2 * a[m, l]
return a_normalized
# The results from scipy are not normalized and the comparison requires
# normalizing the results.
expected_p_vals_normalized = apply_normalization(expected_p_vals)
with self.subTest('Test accuracy.'):
self.assertAllClose(actual_p_vals, expected_p_vals_normalized, rtol=1e-6, atol=3.2e-6)
with self.subTest('Test JIT compatibility'):
args_maker = lambda: [z]
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
self._CompileAndCheck(lsp_special_fn, args_maker)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())