mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Adds associated Legendre functions of the first kind.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This commit is contained in:
parent
46cc654537
commit
a02bf59233
@ -99,6 +99,7 @@ jax.scipy.special
|
|||||||
log_ndtr
|
log_ndtr
|
||||||
logit
|
logit
|
||||||
logsumexp
|
logsumexp
|
||||||
|
lpmn
|
||||||
multigammaln
|
multigammaln
|
||||||
ndtr
|
ndtr
|
||||||
ndtri
|
ndtri
|
||||||
|
@ -18,13 +18,17 @@ import numpy as np
|
|||||||
import scipy.special as osp_special
|
import scipy.special as osp_special
|
||||||
|
|
||||||
from jax._src import api
|
from jax._src import api
|
||||||
|
from jax import jit
|
||||||
from jax import lax, core
|
from jax import lax, core
|
||||||
|
from jax import ops
|
||||||
from jax.interpreters import ad
|
from jax.interpreters import ad
|
||||||
from jax._src.numpy import lax_numpy as jnp
|
from jax._src.numpy import lax_numpy as jnp
|
||||||
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
|
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
|
||||||
_promote_args_inexact)
|
_promote_args_inexact)
|
||||||
from jax._src.numpy.util import _wraps
|
from jax._src.numpy.util import _wraps
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
@_wraps(osp_special.gammaln)
|
@_wraps(osp_special.gammaln)
|
||||||
def gammaln(x):
|
def gammaln(x):
|
||||||
@ -667,3 +671,287 @@ def i1e(x):
|
|||||||
def i1(x):
|
def i1(x):
|
||||||
x, = _promote_args_inexact("i1", x)
|
x, = _promote_args_inexact("i1", x)
|
||||||
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
|
return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_recurrence_mask(
|
||||||
|
l_max: int, is_normalized: bool = True
|
||||||
|
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||||
|
"""Generates mask for recurrence relation on the remaining entries.
|
||||||
|
|
||||||
|
The remaining entries are with respect to the diagonal and offdiagonal
|
||||||
|
entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
l_max: see `gen_normalized_legendre`.
|
||||||
|
is_normalized: True if the recurrence mask is used by normalized associated
|
||||||
|
Legendre functions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Arrays representing the mask used by the recurrence relations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Computes all coefficients.
|
||||||
|
m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1]
|
||||||
|
if is_normalized:
|
||||||
|
c0 = l_mat * l_mat
|
||||||
|
c1 = m_mat * m_mat
|
||||||
|
c2 = 2.0 * l_mat
|
||||||
|
c3 = (l_mat - 1.0) * (l_mat - 1.0)
|
||||||
|
d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
|
||||||
|
d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
|
||||||
|
else:
|
||||||
|
d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
|
||||||
|
d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
|
||||||
|
|
||||||
|
d0_mask_indices = jnp.triu_indices(l_max + 1, 1)
|
||||||
|
d1_mask_indices = jnp.triu_indices(l_max + 1, 2)
|
||||||
|
d_zeros = jnp.zeros((l_max + 1, l_max + 1))
|
||||||
|
d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices])
|
||||||
|
d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices])
|
||||||
|
|
||||||
|
# Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
|
||||||
|
# i = jnp.arange(l_max + 1)[:, None, None]
|
||||||
|
# j = jnp.arange(l_max + 1)[None, :, None]
|
||||||
|
# k = jnp.arange(l_max + 1)[None, None, :]
|
||||||
|
i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1]
|
||||||
|
mask = 1.0 * (i + j - k == 0)
|
||||||
|
|
||||||
|
d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask)
|
||||||
|
d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask)
|
||||||
|
|
||||||
|
return (d0_mask_3d, d1_mask_3d)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnums=(2))
|
||||||
|
def _gen_derivatives(p: jnp.ndarray,
|
||||||
|
x: jnp.ndarray,
|
||||||
|
is_normalized: bool) -> jnp.ndarray:
|
||||||
|
"""Generates derivatives of associated Legendre functions of the first kind.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p: The 3D array containing the values of associated Legendre functions; the
|
||||||
|
dimensions are in the sequence of order (m), degree (l), and evalution
|
||||||
|
points.
|
||||||
|
x: A vector of type `float32` or `float64` containing the sampled points.
|
||||||
|
is_normalized: True if the associated Legendre functions are normalized.
|
||||||
|
Returns:
|
||||||
|
The 3D array representing the derivatives of associated Legendre functions
|
||||||
|
of the first kind.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_m, num_l, num_x = p.shape
|
||||||
|
|
||||||
|
# p_{l-1}^m.
|
||||||
|
p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]
|
||||||
|
|
||||||
|
# p_{l-1}^{m+2}.
|
||||||
|
p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]
|
||||||
|
|
||||||
|
# p_{l-1}^{m-2}.
|
||||||
|
p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]
|
||||||
|
|
||||||
|
# Derivative computation requires negative orders.
|
||||||
|
if is_normalized:
|
||||||
|
raise NotImplementedError(
|
||||||
|
'Negative orders for normalization is not implemented yet.')
|
||||||
|
else:
|
||||||
|
if num_l > 1:
|
||||||
|
l_vec = jnp.arange(1, num_l - 1)
|
||||||
|
p_p1 = p[1, 1:num_l - 1, :]
|
||||||
|
coeff = -1.0 / ((l_vec + 1) * l_vec)
|
||||||
|
update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
|
||||||
|
p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1)
|
||||||
|
|
||||||
|
if num_l > 2:
|
||||||
|
l_vec = jnp.arange(2, num_l - 1)
|
||||||
|
p_p2 = p[2, 2:num_l - 1, :]
|
||||||
|
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
|
||||||
|
update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
|
||||||
|
p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2)
|
||||||
|
|
||||||
|
m_mat, l_mat = jnp.mgrid[:num_m, :num_l]
|
||||||
|
|
||||||
|
coeff_zeros = jnp.zeros((num_m, num_l))
|
||||||
|
upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
|
||||||
|
zero_vec = jnp.zeros((num_l,))
|
||||||
|
|
||||||
|
a0 = -0.5 / (m_mat - 1.0)
|
||||||
|
a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
|
||||||
|
a0_masked = a0_masked.at[1, :].set(zero_vec)
|
||||||
|
|
||||||
|
b0 = l_mat + m_mat
|
||||||
|
c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
|
||||||
|
c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
|
||||||
|
c0_masked = c0_masked.at[1, :].set(zero_vec)
|
||||||
|
|
||||||
|
# p_l^{m-1}.
|
||||||
|
p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
|
||||||
|
jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))
|
||||||
|
|
||||||
|
d0 = -0.5 / (m_mat + 1.0)
|
||||||
|
d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
|
||||||
|
e0 = d0 * b0 * (b0 + 1.0)
|
||||||
|
e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])
|
||||||
|
|
||||||
|
# p_l^{m+1}.
|
||||||
|
p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
|
||||||
|
jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))
|
||||||
|
|
||||||
|
f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
|
||||||
|
f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
|
||||||
|
p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l
|
||||||
|
|
||||||
|
# Special treatment of the singularity at m = 1.
|
||||||
|
if num_m > 1:
|
||||||
|
l_vec = jnp.arange(num_l)
|
||||||
|
g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
|
||||||
|
if num_l > 2:
|
||||||
|
g0 = g0 - p[2, :, :]
|
||||||
|
p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
|
||||||
|
p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
|
||||||
|
p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x,)))
|
||||||
|
|
||||||
|
return p_derivative
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnums=(0, 2))
|
||||||
|
def _gen_associated_legendre(l_max: int,
|
||||||
|
x: jnp.ndarray,
|
||||||
|
is_normalized: bool) -> jnp.ndarray:
|
||||||
|
r"""Computes associated Legendre functions (ALFs) of the first kind.
|
||||||
|
|
||||||
|
The ALFs of the first kind are used in spherical harmonics. The spherical
|
||||||
|
harmonic of degree `l` and order `m` can be written as
|
||||||
|
`Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
|
||||||
|
normalization factor and θ and φ are the colatitude and longitude,
|
||||||
|
repectively. `N_l^m` is chosen in the way that the spherical harmonics form
|
||||||
|
a set of orthonormal basis function of L^2(S^2). For the computational
|
||||||
|
efficiency of spherical harmonics transform, the normalization factor is
|
||||||
|
used in the computation of the ALFs. In addition, normalizing `P_l^m`
|
||||||
|
avoids overflow/underflow and achieves better numerical stability. Three
|
||||||
|
recurrence relations are used in the computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
l_max: The maximum degree of the associated Legendre function. Both the
|
||||||
|
degrees and orders are `[0, 1, 2, ..., l_max]`.
|
||||||
|
x: A vector of type `float32`, `float64` containing the sampled points in
|
||||||
|
spherical coordinates, at which the ALFs are computed; `x` is essentially
|
||||||
|
`cos(θ)`. For the numerical integration used by the spherical harmonics
|
||||||
|
transforms, `x` contains the quadrature points in the interval of
|
||||||
|
`[-1, 1]`. There are several approaches to provide the quadrature points:
|
||||||
|
Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
|
||||||
|
method (`scipy.special.roots_chebyu`), and Driscoll & Healy
|
||||||
|
method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
|
||||||
|
transforms and convolutions on the 2-sphere." Advances in applied
|
||||||
|
mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
|
||||||
|
points are nearly equal-spaced along θ and provide exact discrete
|
||||||
|
orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
|
||||||
|
operation, `W` is a diagonal matrix containing the quadrature weights,
|
||||||
|
and `I` is the identity matrix. The Gauss-Chebyshev points are equally
|
||||||
|
spaced, which only provide approximate discrete orthogonality. The
|
||||||
|
Driscoll & Healy qudarture points are equally spaced and provide the
|
||||||
|
exact discrete orthogonality. The number of sampling points is required to
|
||||||
|
be twice as the number of frequency points (modes) in the Driscoll & Healy
|
||||||
|
approach, which enables FFT and achieves a fast spherical harmonics
|
||||||
|
transform.
|
||||||
|
is_normalized: True if the associated Legendre functions are normalized.
|
||||||
|
With normalization, `N_l^m` is applied such that the spherical harmonics
|
||||||
|
form a set of orthonormal basis functions of L^2(S^2).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
|
||||||
|
of the ALFs at `x`; the dimensions in the sequence of order, degree, and
|
||||||
|
evalution points.
|
||||||
|
"""
|
||||||
|
p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))
|
||||||
|
|
||||||
|
a_idx = jnp.arange(1, l_max + 1)
|
||||||
|
b_idx = jnp.arange(l_max)
|
||||||
|
if is_normalized:
|
||||||
|
initial_value = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0).
|
||||||
|
f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
|
||||||
|
f_b = jnp.sqrt(2.0 * b_idx + 3.0)
|
||||||
|
else:
|
||||||
|
initial_value = 1.0 # The initial value p(0,0).
|
||||||
|
f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
|
||||||
|
f_b = 2.0 * b_idx + 1.0
|
||||||
|
|
||||||
|
p = p.at[(0, 0)].set(initial_value)
|
||||||
|
|
||||||
|
# Compute the diagonal entries p(l,l) with recurrence.
|
||||||
|
y = jnp.cumprod(
|
||||||
|
jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])),
|
||||||
|
axis=0)
|
||||||
|
p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
|
||||||
|
diag_indices = jnp.diag_indices(l_max + 1)
|
||||||
|
p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)
|
||||||
|
|
||||||
|
# Compute the off-diagonal entries with recurrence.
|
||||||
|
p_offdiag = jnp.einsum('ij,ij->ij',
|
||||||
|
jnp.einsum('i,j->ij', f_b, x),
|
||||||
|
p[jnp.diag_indices(l_max)])
|
||||||
|
offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
|
||||||
|
p = p.at[offdiag_indices].set(p_offdiag)
|
||||||
|
|
||||||
|
# Compute the remaining entries with recurrence.
|
||||||
|
d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
|
||||||
|
l_max, is_normalized=is_normalized)
|
||||||
|
|
||||||
|
def body_fun(i, p_val):
|
||||||
|
coeff_0 = d0_mask_3d[i]
|
||||||
|
coeff_1 = d1_mask_3d[i]
|
||||||
|
h = (jnp.einsum('ij,ijk->ijk',
|
||||||
|
coeff_0,
|
||||||
|
jnp.einsum(
|
||||||
|
'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
|
||||||
|
jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1)))
|
||||||
|
p_val = p_val + h
|
||||||
|
return p_val
|
||||||
|
|
||||||
|
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
|
||||||
|
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def lpmn(m, n, z):
|
||||||
|
"""The associated Legendre functions (ALFs) of the first kind.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
m: The maximum order of the associated Legendre functions.
|
||||||
|
n: The maximum degree of the associated Legendre function, often called
|
||||||
|
`l` in describing ALFs. Both the degrees and orders are
|
||||||
|
`[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
|
||||||
|
z: A vector of type `float32` or `float64` containing the sampling
|
||||||
|
points at which the ALFs are computed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 2-tuple of 3D arrays of shape `(l_max + 1, l_max + 1, len(z))` containing
|
||||||
|
the values and derivatives of the associated Legendre functions of the
|
||||||
|
first kind. The return type matches the type of `z`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError if elements of array `z` are not in (float32, float64).
|
||||||
|
ValueError if array `z` is not 1D.
|
||||||
|
NotImplementedError if `m!=n`.
|
||||||
|
"""
|
||||||
|
dtype = lax.dtype(z)
|
||||||
|
if dtype not in (jnp.float32, jnp.float64):
|
||||||
|
raise TypeError(
|
||||||
|
'z.dtype={} is not supported, see docstring for supported types.'
|
||||||
|
.format(dtype))
|
||||||
|
|
||||||
|
if z.ndim != 1:
|
||||||
|
raise ValueError('z must be a 1D array.')
|
||||||
|
|
||||||
|
m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
|
||||||
|
n = core.concrete_or_error(int, n, 'Argument n of lpmn.')
|
||||||
|
|
||||||
|
if m != n:
|
||||||
|
raise NotImplementedError('Computations for m!=n are not yet supported.')
|
||||||
|
|
||||||
|
l_max = n
|
||||||
|
is_normalized = False
|
||||||
|
p_vals = _gen_associated_legendre(l_max, z, is_normalized)
|
||||||
|
p_derivatives = _gen_derivatives(p_vals, z, is_normalized)
|
||||||
|
|
||||||
|
return (p_vals, p_derivatives)
|
||||||
|
@ -32,6 +32,7 @@ from jax._src.scipy.special import (
|
|||||||
i1e,
|
i1e,
|
||||||
logit,
|
logit,
|
||||||
logsumexp,
|
logsumexp,
|
||||||
|
lpmn,
|
||||||
multigammaln,
|
multigammaln,
|
||||||
log_ndtr,
|
log_ndtr,
|
||||||
ndtr,
|
ndtr,
|
||||||
|
@ -239,6 +239,33 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
|||||||
partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
|
partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
|
||||||
self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
|
self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
|
{"testcase_name":
|
||||||
|
"_maxdegree={}_inputsize={}".format(l_max, num_z),
|
||||||
|
"l_max": l_max,
|
||||||
|
"num_z": num_z}
|
||||||
|
for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
|
||||||
|
def testLpmn(self, l_max, num_z):
|
||||||
|
# Points on which the associated Legendre functions areevaluated.
|
||||||
|
z = np.linspace(-0.2, 0.9, num_z)
|
||||||
|
actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z)
|
||||||
|
|
||||||
|
# The expected results are obtained from scipy.
|
||||||
|
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
|
||||||
|
expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))
|
||||||
|
|
||||||
|
for i in range(num_z):
|
||||||
|
val, derivative = osp_special.lpmn(l_max, l_max, z[i])
|
||||||
|
expected_p_vals[:, :, i] = val
|
||||||
|
expected_p_derivatives[:, :, i] = derivative
|
||||||
|
|
||||||
|
with self.subTest('Test values.'):
|
||||||
|
self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6)
|
||||||
|
|
||||||
|
with self.subTest('Test derivatives.'):
|
||||||
|
self.assertAllClose(actual_p_derivatives,expected_p_derivatives,
|
||||||
|
rtol=1e-6, atol=8.4e-4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user