diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst
index 627a1bd09..2d3cc3827 100644
--- a/docs/jax.scipy.rst
+++ b/docs/jax.scipy.rst
@@ -99,6 +99,7 @@ jax.scipy.special
    log_ndtr
    logit
    logsumexp
+   lpmn
    multigammaln
    ndtr
    ndtri
diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py
index 4b6e23a1b..8503708b1 100644
--- a/jax/_src/scipy/special.py
+++ b/jax/_src/scipy/special.py
@@ -18,13 +18,17 @@ import numpy as np
 import scipy.special as osp_special
 
 from jax._src import api
+from jax import jit
 from jax import lax, core
+from jax import ops
 from jax.interpreters import ad
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
                                       _promote_args_inexact)
 from jax._src.numpy.util import _wraps
 
+from typing import Tuple
+
 
 @_wraps(osp_special.gammaln)
 def gammaln(x):
@@ -667,3 +671,287 @@ def i1e(x):
 def i1(x):
   x, = _promote_args_inexact("i1", x)
   return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
+
+
+def _gen_recurrence_mask(
+    l_max: int, is_normalized: bool = True
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Generates mask for recurrence relation on the remaining entries.
+
+  The remaining entries are with respect to the diagonal and offdiagonal
+  entries.
+
+  Args:
+    l_max: see `_gen_associated_legendre`.
+    is_normalized: True if the recurrence mask is used by normalized associated
+      Legendre functions.
+
+  Returns:
+    Arrays representing the mask used by the recurrence relations.
+  """
+
+  # Computes all coefficients.
+  m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1]
+  if is_normalized:
+    c0 = l_mat * l_mat
+    c1 = m_mat * m_mat
+    c2 = 2.0 * l_mat
+    c3 = (l_mat - 1.0) * (l_mat - 1.0)
+    d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
+    d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
+  else:
+    d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
+    d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
+
+  d0_mask_indices = jnp.triu_indices(l_max + 1, 1)
+  d1_mask_indices = jnp.triu_indices(l_max + 1, 2)
+  d_zeros = jnp.zeros((l_max + 1, l_max + 1))
+  d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices])
+  d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices])
+
+  # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
+  # i = jnp.arange(l_max + 1)[:, None, None]
+  # j = jnp.arange(l_max + 1)[None, :, None]
+  # k = jnp.arange(l_max + 1)[None, None, :]
+  i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1]
+  mask = 1.0 * (i + j - k == 0)
+
+  d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask)
+  d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask)
+
+  return (d0_mask_3d, d1_mask_3d)
+
+
+@partial(jit, static_argnums=(2,))
+def _gen_derivatives(p: jnp.ndarray,
+                     x: jnp.ndarray,
+                     is_normalized: bool) -> jnp.ndarray:
+  """Generates derivatives of associated Legendre functions of the first kind.
+
+  Args:
+    p: The 3D array containing the values of associated Legendre functions; the
+      dimensions are in the sequence of order (m), degree (l), and evaluation
+      points.
+    x: A vector of type `float32` or `float64` containing the sampled points.
+    is_normalized: True if the associated Legendre functions are normalized.
+  Returns:
+    The 3D array representing the derivatives of associated Legendre functions
+    of the first kind.
+  """
+
+  num_m, num_l, num_x = p.shape
+
+  # p_{l-1}^m.
+  p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]
+
+  # p_{l-1}^{m+2}.
+  p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]
+
+  # p_{l-1}^{m-2}.
+  p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]
+
+  # Derivative computation requires negative orders.
+  if is_normalized:
+    raise NotImplementedError(
+        'Negative orders for normalization is not implemented yet.')
+  else:
+    if num_l > 1:
+      l_vec = jnp.arange(1, num_l - 1)
+      p_p1 = p[1, 1:num_l - 1, :]
+      coeff = -1.0 / ((l_vec + 1) * l_vec)
+      update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
+      p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1)
+
+    if num_l > 2:
+      l_vec = jnp.arange(2, num_l - 1)
+      p_p2 = p[2, 2:num_l - 1, :]
+      coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
+      update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
+      p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2)
+
+  m_mat, l_mat = jnp.mgrid[:num_m, :num_l]
+
+  coeff_zeros = jnp.zeros((num_m, num_l))
+  upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
+  zero_vec = jnp.zeros((num_l,))
+
+  a0 = -0.5 / (m_mat - 1.0)
+  a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
+  a0_masked = a0_masked.at[1, :].set(zero_vec)
+
+  b0 = l_mat + m_mat
+  c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
+  c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
+  c0_masked = c0_masked.at[1, :].set(zero_vec)
+
+  # p_l^{m-1}.
+  p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
+             jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))
+
+  d0 = -0.5 / (m_mat + 1.0)
+  d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
+  e0 = d0 * b0 * (b0 + 1.0)
+  e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])
+
+  # p_l^{m+1}.
+  p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
+             jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))
+
+  f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
+  f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
+  p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l
+
+  # Special treatment of the singularity at m = 1.
+  if num_m > 1:
+    l_vec = jnp.arange(num_l)
+    g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
+    if num_l > 2:
+      g0 = g0 - p[2, :, :]
+    p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
+    p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
+    p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x,)))
+
+  return p_derivative
+
+
+@partial(jit, static_argnums=(0, 2))
+def _gen_associated_legendre(l_max: int,
+                             x: jnp.ndarray,
+                             is_normalized: bool) -> jnp.ndarray:
+  r"""Computes associated Legendre functions (ALFs) of the first kind.
+
+  The ALFs of the first kind are used in spherical harmonics. The spherical
+  harmonic of degree `l` and order `m` can be written as
+  `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
+  normalization factor and θ and φ are the colatitude and longitude,
+  respectively. `N_l^m` is chosen in the way that the spherical harmonics form
+  a set of orthonormal basis functions of L^2(S^2). For the computational
+  efficiency of spherical harmonics transform, the normalization factor is
+  used in the computation of the ALFs. In addition, normalizing `P_l^m`
+  avoids overflow/underflow and achieves better numerical stability. Three
+  recurrence relations are used in the computation.
+
+  Args:
+    l_max: The maximum degree of the associated Legendre function. Both the
+      degrees and orders are `[0, 1, 2, ..., l_max]`.
+    x: A vector of type `float32`, `float64` containing the sampled points in
+      spherical coordinates, at which the ALFs are computed; `x` is essentially
+      `cos(θ)`. For the numerical integration used by the spherical harmonics
+      transforms, `x` contains the quadrature points in the interval of
+      `[-1, 1]`. There are several approaches to provide the quadrature points:
+      Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
+      method (`scipy.special.roots_chebyu`), and Driscoll & Healy
+      method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
+      transforms and convolutions on the 2-sphere." Advances in applied
+      mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
+      points are nearly equal-spaced along θ and provide exact discrete
+      orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
+      operation, `W` is a diagonal matrix containing the quadrature weights,
+      and `I` is the identity matrix. The Gauss-Chebyshev points are equally
+      spaced, which only provide approximate discrete orthogonality. The
+      Driscoll & Healy quadrature points are equally spaced and provide the
+      exact discrete orthogonality. The number of sampling points is required to
+      be twice as the number of frequency points (modes) in the Driscoll & Healy
+      approach, which enables FFT and achieves a fast spherical harmonics
+      transform.
+    is_normalized: True if the associated Legendre functions are normalized.
+      With normalization, `N_l^m` is applied such that the spherical harmonics
+      form a set of orthonormal basis functions of L^2(S^2).
+
+  Returns:
+    The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
+    of the ALFs at `x`; the dimensions in the sequence of order, degree, and
+    evaluation points.
+  """
+  p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))
+
+  a_idx = jnp.arange(1, l_max + 1)
+  b_idx = jnp.arange(l_max)
+  if is_normalized:
+    initial_value = 0.5 / jnp.sqrt(jnp.pi)  # The initial value p(0,0).
+    f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
+    f_b = jnp.sqrt(2.0 * b_idx + 3.0)
+  else:
+    initial_value = 1.0  # The initial value p(0,0).
+    f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
+    f_b = 2.0 * b_idx + 1.0
+
+  p = p.at[(0, 0)].set(initial_value)
+
+  # Compute the diagonal entries p(l,l) with recurrence.
+  y = jnp.cumprod(
+      jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])),
+      axis=0)
+  p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
+  diag_indices = jnp.diag_indices(l_max + 1)
+  p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)
+
+  # Compute the off-diagonal entries with recurrence.
+  p_offdiag = jnp.einsum('ij,ij->ij',
+                         jnp.einsum('i,j->ij', f_b, x),
+                         p[jnp.diag_indices(l_max)])
+  offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
+  p = p.at[offdiag_indices].set(p_offdiag)
+
+  # Compute the remaining entries with recurrence.
+  d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
+      l_max, is_normalized=is_normalized)
+
+  def body_fun(i, p_val):
+    coeff_0 = d0_mask_3d[i]
+    coeff_1 = d1_mask_3d[i]
+    h = (jnp.einsum('ij,ijk->ijk',
+                    coeff_0,
+                    jnp.einsum(
+                        'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
+         jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1)))
+    p_val = p_val + h
+    return p_val
+
+  p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
+
+  return p
+
+
+def lpmn(m, n, z):
+  """The associated Legendre functions (ALFs) of the first kind.
+
+  Args:
+    m: The maximum order of the associated Legendre functions.
+    n: The maximum degree of the associated Legendre function, often called
+      `l` in describing ALFs. Both the degrees and orders are
+      `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
+    z: A vector of type `float32` or `float64` containing the sampling
+      points at which the ALFs are computed.
+
+  Returns:
+    A 2-tuple of 3D arrays of shape `(l_max + 1, l_max + 1, len(z))` containing
+    the values and derivatives of the associated Legendre functions of the
+    first kind. The return type matches the type of `z`.
+
+  Raises:
+    TypeError if elements of array `z` are not in (float32, float64).
+    ValueError if array `z` is not 1D.
+    NotImplementedError if `m!=n`.
+  """
+  dtype = lax.dtype(z)
+  if dtype not in (jnp.float32, jnp.float64):
+    raise TypeError(
+        'z.dtype={} is not supported, see docstring for supported types.'
+        .format(dtype))
+
+  if z.ndim != 1:
+    raise ValueError('z must be a 1D array.')
+
+  m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
+  n = core.concrete_or_error(int, n, 'Argument n of lpmn.')
+
+  if m != n:
+    raise NotImplementedError('Computations for m!=n are not yet supported.')
+
+  l_max = n
+  is_normalized = False
+  p_vals = _gen_associated_legendre(l_max, z, is_normalized)
+  p_derivatives = _gen_derivatives(p_vals, z, is_normalized)
+
+  return (p_vals, p_derivatives)
diff --git a/jax/scipy/special.py b/jax/scipy/special.py
index 75a4b3e14..be6801159 100644
--- a/jax/scipy/special.py
+++ b/jax/scipy/special.py
@@ -32,6 +32,7 @@ from jax._src.scipy.special import (
   i1e,
   logit,
   logsumexp,
+  lpmn,
   multigammaln,
   log_ndtr,
   ndtr,
diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index ad2753f70..3094abbf1 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -239,6 +239,33 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
     self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+    {"testcase_name":
+     "_maxdegree={}_inputsize={}".format(l_max, num_z),
+     "l_max": l_max,
+     "num_z": num_z}
+    for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
+  def testLpmn(self, l_max, num_z):
+    # Points on which the associated Legendre functions are evaluated.
+    z = np.linspace(-0.2, 0.9, num_z)
+    actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z)
+
+    # The expected results are obtained from scipy.
+    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
+    expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))
+
+    for i in range(num_z):
+      val, derivative = osp_special.lpmn(l_max, l_max, z[i])
+      expected_p_vals[:, :, i] = val
+      expected_p_derivatives[:, :, i] = derivative
+
+    with self.subTest('Test values.'):
+      self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6)
+
+    with self.subTest('Test derivatives.'):
+      self.assertAllClose(actual_p_derivatives, expected_p_derivatives,
+                          rtol=1e-6, atol=8.4e-4)
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())