diff --git a/CHANGELOG.md b/CHANGELOG.md index 29c1d26f8..93a063a87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.2.17 (unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main). +* New features: + * New SciPy function {py:func}`jax.scipy.special.sph_harm`. ## jaxlib 0.1.69 (unreleased) @@ -23,6 +25,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. with significant dispatch performance improvements on CPU. * The {func}`jax2tf.convert` supports inequalities and min/max for booleans ({jax-issue}`#6956`). + * New SciPy function {py:func}`jax.scipy.special.lpmn_values`. * Breaking changes: * Support for NumPy 1.16 has been dropped, per the @@ -51,6 +54,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * The {func}`jax2tf.convert` generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA. + * New SciPy function {py:func}`jax.scipy.special.lpmn`. * Bug fixes: * The {func}`jax2tf.convert` now ensures that it uses the same typing rules diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index a282fbbb3..1a0bbd39d 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -105,6 +105,7 @@ jax.scipy.special ndtr ndtri polygamma + sph_harm xlog1py xlogy zeta diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 3edde9524..7c39db27f 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -27,7 +27,7 @@ from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like, _promote_args_inexact) from jax._src.numpy.util import _wraps -from typing import Tuple +from typing import Optional, Tuple @_wraps(osp_special.gammaln) @@ -909,7 +909,8 @@ def _gen_associated_legendre(l_max: int, p_val = p_val + h return p_val - p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p) + if l_max > 1: + p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p) return p @@ -1010,3 +1011,77 @@ def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndar l_max = n return _gen_associated_legendre(l_max, z, is_normalized) + + + +@partial(jit, static_argnums=(4,)) +def _sph_harm(m: jnp.ndarray, + n: jnp.ndarray, + theta: jnp.ndarray, + phi: jnp.ndarray, + n_max: int) -> jnp.ndarray: + """Computes the spherical harmonics.""" + + cos_colatitude = jnp.cos(phi) + + legendre = _gen_associated_legendre(n_max, cos_colatitude, True) + legendre_val = legendre[abs(m), n, jnp.arange(len(n))] + + angle = abs(m) * theta + vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle)) + harmonics = lax.complex(legendre_val * jnp.real(vandermonde), + legendre_val * jnp.imag(vandermonde)) + + # Negative order. + harmonics = jnp.where(m < 0, + (-1.0)**abs(m) * jnp.conjugate(harmonics), + harmonics) + + return harmonics + + +def sph_harm(m: jnp.ndarray, + n: jnp.ndarray, + theta: jnp.ndarray, + phi: jnp.ndarray, + n_max: Optional[int] = None) -> jnp.ndarray: + r"""Computes the spherical harmonics. + + The JAX version has one extra argument `n_max`, the maximum value in `n`. + + The spherical harmonic of degree `n` and order `m` can be written as + :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`, + where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!} + {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and + :math:\theta` are the colatitude and longitude, repectively. :math:`N_n^m` is + chosen in the way that the spherical harmonics form a set of orthonormal basis + functions of :math:`L^2(S^2)`. + + Args: + m: The order of the harmonic; must have `|m| <= n`. Return values for + `|m| > n` ara undefined. + n: The degree of the harmonic; must have `n >= 0`. The standard notation for + degree in descriptions of spherical harmonics is `l (lower case L)`. We + use `n` here to be consistent with `scipy.special.sph_harm`. Return + values for `n < 0` are undefined. + theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi]. + phi: The polar (colatitudinal) coordinate; must be in [0, pi]. + n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true + maximum value of `n`, the results are clipped to `n_max`. For example, + `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)` + acutually returns + `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)` + Returns: + A 1D array containing the spherical harmonics at (m, n, theta, phi). + """ + + if jnp.isscalar(phi): + phi = jnp.array([phi]) + + if n_max is None: + n_max = jnp.max(n) + n_max = core.concrete_or_error( + int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must ' + 'be statically specified to use `sph_harm` within JAX transformations.') + + return _sph_harm(m, n, theta, phi, n_max) diff --git a/jax/scipy/special.py b/jax/scipy/special.py index d9fb19b5b..18a9f37dc 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -39,6 +39,7 @@ from jax._src.scipy.special import ( ndtr, ndtri, polygamma, + sph_harm, xlogy, xlog1py, zeta, diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 985b9a00f..5d0e35a69 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -26,6 +26,7 @@ import numpy as np import scipy.special as osp_special from jax._src import api +from jax import numpy as jnp from jax import test_util as jtu from jax.scipy import special as lsp_special @@ -312,6 +313,98 @@ class LaxBackedScipyTests(jtu.JaxTestCase): lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized) self._CompileAndCheck(lsp_special_fn, args_maker) + def testSphHarmAccuracy(self): + m = jnp.arange(-3, 3)[:, None] + n = jnp.arange(3, 6) + n_max = 5 + theta = 0.0 + phi = jnp.pi + + expected = lsp_special.sph_harm(m, n, theta, phi, n_max) + + actual = osp_special.sph_harm(m, n, theta, phi) + + self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5) + + def testSphHarmOrderZeroDegreeZero(self): + """Tests the spherical harmonics of order zero and degree zero.""" + theta = jnp.array([0.3]) + phi = jnp.array([2.3]) + n_max = 0 + + expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)]) + actual = jnp.real( + lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max)) + + self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8) + + def testSphHarmOrderZeroDegreeOne(self): + """Tests the spherical harmonics of order one and degree zero.""" + theta = jnp.array([2.0]) + phi = jnp.array([3.1]) + n_max = 1 + + expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi) + actual = jnp.real( + lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max)) + + self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8) + + def testSphHarmOrderOneDegreeOne(self): + """Tests the spherical harmonics of order one and degree one.""" + theta = jnp.array([2.0]) + phi = jnp.array([2.5]) + n_max = 1 + + expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) * + jnp.sin(phi) * jnp.exp(1j * theta)) + actual = lsp_special.sph_harm( + jnp.array([1]), jnp.array([1]), theta, phi, n_max) + + self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8) + + @parameterized.named_parameters(jtu.cases_from_list( + {'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format( + l_max, num_z, dtype), + 'l_max': l_max, 'num_z': num_z, 'dtype': dtype} + for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8]) + for dtype in jtu.dtypes.all_integer)) + def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): + """Tests against JIT compatibility and Numpy.""" + n_max = l_max + shape = (num_z,) + rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) + + lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max) + + def args_maker(): + m = rng(shape, dtype) + n = abs(m) + theta = jnp.linspace(-4.0, 5.0, num_z) + phi = jnp.linspace(-2.0, 1.0, num_z) + return m, n, theta, phi + + with self.subTest('Test JIT compatibility'): + self._CompileAndCheck(lsp_special_fn, args_maker) + + with self.subTest('Test against numpy.'): + self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker) + + def testSphHarmCornerCaseWithWrongNmax(self): + """Tests the corner case where `n_max` is not the maximum value of `n`.""" + m = jnp.array([2]) + n = jnp.array([10]) + n_clipped = jnp.array([6]) + n_max = 6 + theta = jnp.array([0.9]) + phi = jnp.array([0.2]) + + expected = lsp_special.sph_harm(m, n, theta, phi, n_max) + + actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max) + + self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())