mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Adds spherical harmonics.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This commit is contained in:
parent
c97d63dec3
commit
d97b393694
@ -10,6 +10,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
|
||||
## jax 0.2.17 (unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
|
||||
* New features:
|
||||
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
|
||||
|
||||
## jaxlib 0.1.69 (unreleased)
|
||||
|
||||
@ -23,6 +25,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
with significant dispatch performance improvements on CPU.
|
||||
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
|
||||
({jax-issue}`#6956`).
|
||||
* New SciPy function {py:func}`jax.scipy.special.lpmn_values`.
|
||||
|
||||
* Breaking changes:
|
||||
* Support for NumPy 1.16 has been dropped, per the
|
||||
@ -51,6 +54,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* The {func}`jax2tf.convert` generates custom attributes with location information
|
||||
in TF ops. The code that XLA generates after jax2tf
|
||||
has the same location information as JAX/XLA.
|
||||
* New SciPy function {py:func}`jax.scipy.special.lpmn`.
|
||||
|
||||
* Bug fixes:
|
||||
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
|
||||
|
@ -105,6 +105,7 @@ jax.scipy.special
|
||||
ndtr
|
||||
ndtri
|
||||
polygamma
|
||||
sph_harm
|
||||
xlog1py
|
||||
xlogy
|
||||
zeta
|
||||
|
@ -27,7 +27,7 @@ from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
|
||||
_promote_args_inexact)
|
||||
from jax._src.numpy.util import _wraps
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@_wraps(osp_special.gammaln)
|
||||
@ -909,6 +909,7 @@ def _gen_associated_legendre(l_max: int,
|
||||
p_val = p_val + h
|
||||
return p_val
|
||||
|
||||
if l_max > 1:
|
||||
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
|
||||
|
||||
return p
|
||||
@ -1010,3 +1011,77 @@ def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndar
|
||||
l_max = n
|
||||
|
||||
return _gen_associated_legendre(l_max, z, is_normalized)
|
||||
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(4,))
|
||||
def _sph_harm(m: jnp.ndarray,
|
||||
n: jnp.ndarray,
|
||||
theta: jnp.ndarray,
|
||||
phi: jnp.ndarray,
|
||||
n_max: int) -> jnp.ndarray:
|
||||
"""Computes the spherical harmonics."""
|
||||
|
||||
cos_colatitude = jnp.cos(phi)
|
||||
|
||||
legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
|
||||
legendre_val = legendre[abs(m), n, jnp.arange(len(n))]
|
||||
|
||||
angle = abs(m) * theta
|
||||
vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
|
||||
harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
|
||||
legendre_val * jnp.imag(vandermonde))
|
||||
|
||||
# Negative order.
|
||||
harmonics = jnp.where(m < 0,
|
||||
(-1.0)**abs(m) * jnp.conjugate(harmonics),
|
||||
harmonics)
|
||||
|
||||
return harmonics
|
||||
|
||||
|
||||
def sph_harm(m: jnp.ndarray,
|
||||
n: jnp.ndarray,
|
||||
theta: jnp.ndarray,
|
||||
phi: jnp.ndarray,
|
||||
n_max: Optional[int] = None) -> jnp.ndarray:
|
||||
r"""Computes the spherical harmonics.
|
||||
|
||||
The JAX version has one extra argument `n_max`, the maximum value in `n`.
|
||||
|
||||
The spherical harmonic of degree `n` and order `m` can be written as
|
||||
:math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
|
||||
where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
|
||||
{4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and
|
||||
:math:\theta` are the colatitude and longitude, repectively. :math:`N_n^m` is
|
||||
chosen in the way that the spherical harmonics form a set of orthonormal basis
|
||||
functions of :math:`L^2(S^2)`.
|
||||
|
||||
Args:
|
||||
m: The order of the harmonic; must have `|m| <= n`. Return values for
|
||||
`|m| > n` ara undefined.
|
||||
n: The degree of the harmonic; must have `n >= 0`. The standard notation for
|
||||
degree in descriptions of spherical harmonics is `l (lower case L)`. We
|
||||
use `n` here to be consistent with `scipy.special.sph_harm`. Return
|
||||
values for `n < 0` are undefined.
|
||||
theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
|
||||
phi: The polar (colatitudinal) coordinate; must be in [0, pi].
|
||||
n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
|
||||
maximum value of `n`, the results are clipped to `n_max`. For example,
|
||||
`sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
|
||||
acutually returns
|
||||
`sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`
|
||||
Returns:
|
||||
A 1D array containing the spherical harmonics at (m, n, theta, phi).
|
||||
"""
|
||||
|
||||
if jnp.isscalar(phi):
|
||||
phi = jnp.array([phi])
|
||||
|
||||
if n_max is None:
|
||||
n_max = jnp.max(n)
|
||||
n_max = core.concrete_or_error(
|
||||
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
|
||||
'be statically specified to use `sph_harm` within JAX transformations.')
|
||||
|
||||
return _sph_harm(m, n, theta, phi, n_max)
|
||||
|
@ -39,6 +39,7 @@ from jax._src.scipy.special import (
|
||||
ndtr,
|
||||
ndtri,
|
||||
polygamma,
|
||||
sph_harm,
|
||||
xlogy,
|
||||
xlog1py,
|
||||
zeta,
|
||||
|
@ -26,6 +26,7 @@ import numpy as np
|
||||
import scipy.special as osp_special
|
||||
|
||||
from jax._src import api
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
|
||||
@ -312,6 +313,98 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
|
||||
self._CompileAndCheck(lsp_special_fn, args_maker)
|
||||
|
||||
def testSphHarmAccuracy(self):
|
||||
m = jnp.arange(-3, 3)[:, None]
|
||||
n = jnp.arange(3, 6)
|
||||
n_max = 5
|
||||
theta = 0.0
|
||||
phi = jnp.pi
|
||||
|
||||
expected = lsp_special.sph_harm(m, n, theta, phi, n_max)
|
||||
|
||||
actual = osp_special.sph_harm(m, n, theta, phi)
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||
|
||||
def testSphHarmOrderZeroDegreeZero(self):
|
||||
"""Tests the spherical harmonics of order zero and degree zero."""
|
||||
theta = jnp.array([0.3])
|
||||
phi = jnp.array([2.3])
|
||||
n_max = 0
|
||||
|
||||
expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
|
||||
actual = jnp.real(
|
||||
lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max))
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
|
||||
|
||||
def testSphHarmOrderZeroDegreeOne(self):
|
||||
"""Tests the spherical harmonics of order one and degree zero."""
|
||||
theta = jnp.array([2.0])
|
||||
phi = jnp.array([3.1])
|
||||
n_max = 1
|
||||
|
||||
expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
|
||||
actual = jnp.real(
|
||||
lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max))
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8)
|
||||
|
||||
def testSphHarmOrderOneDegreeOne(self):
|
||||
"""Tests the spherical harmonics of order one and degree one."""
|
||||
theta = jnp.array([2.0])
|
||||
phi = jnp.array([2.5])
|
||||
n_max = 1
|
||||
|
||||
expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) *
|
||||
jnp.sin(phi) * jnp.exp(1j * theta))
|
||||
actual = lsp_special.sph_harm(
|
||||
jnp.array([1]), jnp.array([1]), theta, phi, n_max)
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
|
||||
l_max, num_z, dtype),
|
||||
'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
|
||||
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
|
||||
for dtype in jtu.dtypes.all_integer))
|
||||
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
|
||||
"""Tests against JIT compatibility and Numpy."""
|
||||
n_max = l_max
|
||||
shape = (num_z,)
|
||||
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
|
||||
|
||||
lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)
|
||||
|
||||
def args_maker():
|
||||
m = rng(shape, dtype)
|
||||
n = abs(m)
|
||||
theta = jnp.linspace(-4.0, 5.0, num_z)
|
||||
phi = jnp.linspace(-2.0, 1.0, num_z)
|
||||
return m, n, theta, phi
|
||||
|
||||
with self.subTest('Test JIT compatibility'):
|
||||
self._CompileAndCheck(lsp_special_fn, args_maker)
|
||||
|
||||
with self.subTest('Test against numpy.'):
|
||||
self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
|
||||
|
||||
def testSphHarmCornerCaseWithWrongNmax(self):
|
||||
"""Tests the corner case where `n_max` is not the maximum value of `n`."""
|
||||
m = jnp.array([2])
|
||||
n = jnp.array([10])
|
||||
n_clipped = jnp.array([6])
|
||||
n_max = 6
|
||||
theta = jnp.array([0.9])
|
||||
phi = jnp.array([0.2])
|
||||
|
||||
expected = lsp_special.sph_harm(m, n, theta, phi, n_max)
|
||||
|
||||
actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user