Adds spherical harmonics.

Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This commit is contained in:
tlu7 2021-07-02 10:42:29 -07:00
parent c97d63dec3
commit d97b393694
5 changed files with 176 additions and 2 deletions

View File

@ -10,6 +10,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.17 (unreleased) ## jax 0.2.17 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main). * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
* New features:
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
## jaxlib 0.1.69 (unreleased) ## jaxlib 0.1.69 (unreleased)
@ -23,6 +25,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
with significant dispatch performance improvements on CPU. with significant dispatch performance improvements on CPU.
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans * The {func}`jax2tf.convert` supports inequalities and min/max for booleans
({jax-issue}`#6956`). ({jax-issue}`#6956`).
* New SciPy function {py:func}`jax.scipy.special.lpmn_values`.
* Breaking changes: * Breaking changes:
* Support for NumPy 1.16 has been dropped, per the * Support for NumPy 1.16 has been dropped, per the
@ -51,6 +54,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* The {func}`jax2tf.convert` generates custom attributes with location information * The {func}`jax2tf.convert` generates custom attributes with location information
in TF ops. The code that XLA generates after jax2tf in TF ops. The code that XLA generates after jax2tf
has the same location information as JAX/XLA. has the same location information as JAX/XLA.
* New SciPy function {py:func}`jax.scipy.special.lpmn`.
* Bug fixes: * Bug fixes:
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules * The {func}`jax2tf.convert` now ensures that it uses the same typing rules

View File

@ -105,6 +105,7 @@ jax.scipy.special
ndtr ndtr
ndtri ndtri
polygamma polygamma
sph_harm
xlog1py xlog1py
xlogy xlogy
zeta zeta

View File

@ -27,7 +27,7 @@ from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
_promote_args_inexact) _promote_args_inexact)
from jax._src.numpy.util import _wraps from jax._src.numpy.util import _wraps
from typing import Tuple from typing import Optional, Tuple
@_wraps(osp_special.gammaln) @_wraps(osp_special.gammaln)
@ -909,7 +909,8 @@ def _gen_associated_legendre(l_max: int,
p_val = p_val + h p_val = p_val + h
return p_val return p_val
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p) if l_max > 1:
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
return p return p
@ -1010,3 +1011,77 @@ def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndar
l_max = n l_max = n
return _gen_associated_legendre(l_max, z, is_normalized) return _gen_associated_legendre(l_max, z, is_normalized)
@partial(jit, static_argnums=(4,))
def _sph_harm(m: jnp.ndarray,
n: jnp.ndarray,
theta: jnp.ndarray,
phi: jnp.ndarray,
n_max: int) -> jnp.ndarray:
"""Computes the spherical harmonics."""
cos_colatitude = jnp.cos(phi)
legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
legendre_val = legendre[abs(m), n, jnp.arange(len(n))]
angle = abs(m) * theta
vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
legendre_val * jnp.imag(vandermonde))
# Negative order.
harmonics = jnp.where(m < 0,
(-1.0)**abs(m) * jnp.conjugate(harmonics),
harmonics)
return harmonics
def sph_harm(m: jnp.ndarray,
n: jnp.ndarray,
theta: jnp.ndarray,
phi: jnp.ndarray,
n_max: Optional[int] = None) -> jnp.ndarray:
r"""Computes the spherical harmonics.
The JAX version has one extra argument `n_max`, the maximum value in `n`.
The spherical harmonic of degree `n` and order `m` can be written as
:math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
{4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and
:math:\theta` are the colatitude and longitude, repectively. :math:`N_n^m` is
chosen in the way that the spherical harmonics form a set of orthonormal basis
functions of :math:`L^2(S^2)`.
Args:
m: The order of the harmonic; must have `|m| <= n`. Return values for
`|m| > n` ara undefined.
n: The degree of the harmonic; must have `n >= 0`. The standard notation for
degree in descriptions of spherical harmonics is `l (lower case L)`. We
use `n` here to be consistent with `scipy.special.sph_harm`. Return
values for `n < 0` are undefined.
theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
phi: The polar (colatitudinal) coordinate; must be in [0, pi].
n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
maximum value of `n`, the results are clipped to `n_max`. For example,
`sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
acutually returns
`sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`
Returns:
A 1D array containing the spherical harmonics at (m, n, theta, phi).
"""
if jnp.isscalar(phi):
phi = jnp.array([phi])
if n_max is None:
n_max = jnp.max(n)
n_max = core.concrete_or_error(
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
'be statically specified to use `sph_harm` within JAX transformations.')
return _sph_harm(m, n, theta, phi, n_max)

View File

@ -39,6 +39,7 @@ from jax._src.scipy.special import (
ndtr, ndtr,
ndtri, ndtri,
polygamma, polygamma,
sph_harm,
xlogy, xlogy,
xlog1py, xlog1py,
zeta, zeta,

View File

@ -26,6 +26,7 @@ import numpy as np
import scipy.special as osp_special import scipy.special as osp_special
from jax._src import api from jax._src import api
from jax import numpy as jnp
from jax import test_util as jtu from jax import test_util as jtu
from jax.scipy import special as lsp_special from jax.scipy import special as lsp_special
@ -312,6 +313,98 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized) lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
self._CompileAndCheck(lsp_special_fn, args_maker) self._CompileAndCheck(lsp_special_fn, args_maker)
def testSphHarmAccuracy(self):
m = jnp.arange(-3, 3)[:, None]
n = jnp.arange(3, 6)
n_max = 5
theta = 0.0
phi = jnp.pi
expected = lsp_special.sph_harm(m, n, theta, phi, n_max)
actual = osp_special.sph_harm(m, n, theta, phi)
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
def testSphHarmOrderZeroDegreeZero(self):
"""Tests the spherical harmonics of order zero and degree zero."""
theta = jnp.array([0.3])
phi = jnp.array([2.3])
n_max = 0
expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
actual = jnp.real(
lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max))
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
def testSphHarmOrderZeroDegreeOne(self):
"""Tests the spherical harmonics of order one and degree zero."""
theta = jnp.array([2.0])
phi = jnp.array([3.1])
n_max = 1
expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
actual = jnp.real(
lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max))
self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8)
def testSphHarmOrderOneDegreeOne(self):
"""Tests the spherical harmonics of order one and degree one."""
theta = jnp.array([2.0])
phi = jnp.array([2.5])
n_max = 1
expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) *
jnp.sin(phi) * jnp.exp(1j * theta))
actual = lsp_special.sph_harm(
jnp.array([1]), jnp.array([1]), theta, phi, n_max)
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
l_max, num_z, dtype),
'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
for dtype in jtu.dtypes.all_integer))
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
"""Tests against JIT compatibility and Numpy."""
n_max = l_max
shape = (num_z,)
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)
def args_maker():
m = rng(shape, dtype)
n = abs(m)
theta = jnp.linspace(-4.0, 5.0, num_z)
phi = jnp.linspace(-2.0, 1.0, num_z)
return m, n, theta, phi
with self.subTest('Test JIT compatibility'):
self._CompileAndCheck(lsp_special_fn, args_maker)
with self.subTest('Test against numpy.'):
self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
def testSphHarmCornerCaseWithWrongNmax(self):
"""Tests the corner case where `n_max` is not the maximum value of `n`."""
m = jnp.array([2])
n = jnp.array([10])
n_clipped = jnp.array([6])
n_max = 6
theta = jnp.array([0.9])
phi = jnp.array([0.2])
expected = lsp_special.sph_harm(m, n, theta, phi, n_max)
actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
if __name__ == "__main__": if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader()) absltest.main(testLoader=jtu.JaxTestLoader())