From 095e6507b9f8c69c3a5f4c0fcd68033dae428bd6 Mon Sep 17 00:00:00 2001 From: tlu7 Date: Mon, 14 Jun 2021 14:51:37 -0700 Subject: [PATCH] Support value computation of associated Legendre functions. Co-authored-by: Jake VanderPlas --- docs/jax.scipy.rst | 1 + jax/_src/scipy/special.py | 56 ++++++++++++++++++++++++++++++++++++++- jax/scipy/special.py | 1 + tests/lax_scipy_test.py | 46 ++++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 1 deletion(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 2d3cc3827..a282fbbb3 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -100,6 +100,7 @@ jax.scipy.special logit logsumexp lpmn + lpmn_values multigammaln ndtr ndtri diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 2a559edb0..3edde9524 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -914,7 +914,7 @@ def _gen_associated_legendre(l_max: int, return p -def lpmn(m, n, z): +def lpmn(m: int, n: int, z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: """The associated Legendre functions (ALFs) of the first kind. Args: @@ -956,3 +956,57 @@ def lpmn(m, n, z): p_derivatives = _gen_derivatives(p_vals, z, is_normalized) return (p_vals, p_derivatives) + + +def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndarray: + r"""The associated Legendre functions (ALFs) of the first kind. + + Unlike `lpmn`, this function only computes the values of ALFs. + The ALFs of the first kind can be used in spherical harmonics. The + spherical harmonic of degree `l` and order `m` can be written as + :math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`, + where :math:`N_l^m` is the normalization factor and θ and φ are the + colatitude and longitude, repectively. :math:`N_l^m` is chosen in the + way that the spherical harmonics form a set of orthonormal basis function + of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow + and achieves better numerical stability. + + Args: + m: The maximum order of the associated Legendre functions. + n: The maximum degree of the associated Legendre function, often called + `l` in describing ALFs. Both the degrees and orders are + `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree. + z: A vector of type `float32` or `float64` containing the sampling + points at which the ALFs are computed. + is_normalized: True if the associated Legendre functions are normalized. + With normalization, :math:`N_l^m` is applied such that the spherical + harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`. + + Returns: + A 3D array of shape `(l_max + 1, l_max + 1, len(z))` containing + the values of the associated Legendre functions of the first kind. The + return type matches the type of `z`. + + Raises: + TypeError if elements of array `z` are not in (float32, float64). + ValueError if array `z` is not 1D. + NotImplementedError if `m!=n`. + """ + dtype = lax.dtype(z) + if dtype not in (jnp.float32, jnp.float64): + raise TypeError( + 'z.dtype={} is not supported, see docstring for supported types.' + .format(dtype)) + + if z.ndim != 1: + raise ValueError('z must be a 1D array.') + + m = core.concrete_or_error(int, m, 'Argument m of lpmn.') + n = core.concrete_or_error(int, n, 'Argument n of lpmn.') + + if m != n: + raise NotImplementedError('Computations for m!=n are not yet supported.') + + l_max = n + + return _gen_associated_legendre(l_max, z, is_normalized) diff --git a/jax/scipy/special.py b/jax/scipy/special.py index be6801159..d9fb19b5b 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -33,6 +33,7 @@ from jax._src.scipy.special import ( logit, logsumexp, lpmn, + lpmn_values, multigammaln, log_ndtr, ndtr, diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 3094abbf1..985b9a00f 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -266,6 +266,52 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(actual_p_derivatives,expected_p_derivatives, rtol=1e-6, atol=8.4e-4) + with self.subTest('Test JIT compatibility'): + args_maker = lambda: [z] + lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z) + self._CompileAndCheck(lsp_special_fn, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_maxdegree={}_inputsize={}".format(l_max, num_z), + "l_max": l_max, + "num_z": num_z} + for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64]))) + def testNormalizedLpmnValues(self, l_max, num_z): + # Points on which the associated Legendre functions areevaluated. + z = np.linspace(-0.2, 0.9, num_z) + is_normalized = True + actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized) + + # The expected results are obtained from scipy. + expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z)) + for i in range(num_z): + expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0] + + def apply_normalization(a): + """Applies normalization to the associated Legendre functions.""" + num_m, num_l, _ = a.shape + a_normalized = np.zeros_like(a) + for m in range(num_m): + for l in range(num_l): + c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m) + c1 = (4.0 * np.pi) * osp_special.factorial(l + m) + c2 = np.sqrt(c0 / c1) + a_normalized[m, l] = c2 * a[m, l] + return a_normalized + + # The results from scipy are not normalized and the comparison requires + # normalizing the results. + expected_p_vals_normalized = apply_normalization(expected_p_vals) + + with self.subTest('Test accuracy.'): + self.assertAllClose(actual_p_vals, expected_p_vals_normalized, rtol=1e-6, atol=3.2e-6) + + with self.subTest('Test JIT compatibility'): + args_maker = lambda: [z] + lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized) + self._CompileAndCheck(lsp_special_fn, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())