mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Adds spherical harmonics.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This commit is contained in:
parent
c97d63dec3
commit
d97b393694
@ -10,6 +10,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
|
|
||||||
## jax 0.2.17 (unreleased)
|
## jax 0.2.17 (unreleased)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
|
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
|
||||||
|
* New features:
|
||||||
|
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
|
||||||
|
|
||||||
## jaxlib 0.1.69 (unreleased)
|
## jaxlib 0.1.69 (unreleased)
|
||||||
|
|
||||||
@ -23,6 +25,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
with significant dispatch performance improvements on CPU.
|
with significant dispatch performance improvements on CPU.
|
||||||
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
|
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
|
||||||
({jax-issue}`#6956`).
|
({jax-issue}`#6956`).
|
||||||
|
* New SciPy function {py:func}`jax.scipy.special.lpmn_values`.
|
||||||
|
|
||||||
* Breaking changes:
|
* Breaking changes:
|
||||||
* Support for NumPy 1.16 has been dropped, per the
|
* Support for NumPy 1.16 has been dropped, per the
|
||||||
@ -51,6 +54,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* The {func}`jax2tf.convert` generates custom attributes with location information
|
* The {func}`jax2tf.convert` generates custom attributes with location information
|
||||||
in TF ops. The code that XLA generates after jax2tf
|
in TF ops. The code that XLA generates after jax2tf
|
||||||
has the same location information as JAX/XLA.
|
has the same location information as JAX/XLA.
|
||||||
|
* New SciPy function {py:func}`jax.scipy.special.lpmn`.
|
||||||
|
|
||||||
* Bug fixes:
|
* Bug fixes:
|
||||||
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
|
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
|
||||||
|
@ -105,6 +105,7 @@ jax.scipy.special
|
|||||||
ndtr
|
ndtr
|
||||||
ndtri
|
ndtri
|
||||||
polygamma
|
polygamma
|
||||||
|
sph_harm
|
||||||
xlog1py
|
xlog1py
|
||||||
xlogy
|
xlogy
|
||||||
zeta
|
zeta
|
||||||
|
@ -27,7 +27,7 @@ from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
|
|||||||
_promote_args_inexact)
|
_promote_args_inexact)
|
||||||
from jax._src.numpy.util import _wraps
|
from jax._src.numpy.util import _wraps
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
@_wraps(osp_special.gammaln)
|
@_wraps(osp_special.gammaln)
|
||||||
@ -909,7 +909,8 @@ def _gen_associated_legendre(l_max: int,
|
|||||||
p_val = p_val + h
|
p_val = p_val + h
|
||||||
return p_val
|
return p_val
|
||||||
|
|
||||||
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
|
if l_max > 1:
|
||||||
|
p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
@ -1010,3 +1011,77 @@ def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndar
|
|||||||
l_max = n
|
l_max = n
|
||||||
|
|
||||||
return _gen_associated_legendre(l_max, z, is_normalized)
|
return _gen_associated_legendre(l_max, z, is_normalized)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnums=(4,))
|
||||||
|
def _sph_harm(m: jnp.ndarray,
|
||||||
|
n: jnp.ndarray,
|
||||||
|
theta: jnp.ndarray,
|
||||||
|
phi: jnp.ndarray,
|
||||||
|
n_max: int) -> jnp.ndarray:
|
||||||
|
"""Computes the spherical harmonics."""
|
||||||
|
|
||||||
|
cos_colatitude = jnp.cos(phi)
|
||||||
|
|
||||||
|
legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
|
||||||
|
legendre_val = legendre[abs(m), n, jnp.arange(len(n))]
|
||||||
|
|
||||||
|
angle = abs(m) * theta
|
||||||
|
vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
|
||||||
|
harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
|
||||||
|
legendre_val * jnp.imag(vandermonde))
|
||||||
|
|
||||||
|
# Negative order.
|
||||||
|
harmonics = jnp.where(m < 0,
|
||||||
|
(-1.0)**abs(m) * jnp.conjugate(harmonics),
|
||||||
|
harmonics)
|
||||||
|
|
||||||
|
return harmonics
|
||||||
|
|
||||||
|
|
||||||
|
def sph_harm(m: jnp.ndarray,
|
||||||
|
n: jnp.ndarray,
|
||||||
|
theta: jnp.ndarray,
|
||||||
|
phi: jnp.ndarray,
|
||||||
|
n_max: Optional[int] = None) -> jnp.ndarray:
|
||||||
|
r"""Computes the spherical harmonics.
|
||||||
|
|
||||||
|
The JAX version has one extra argument `n_max`, the maximum value in `n`.
|
||||||
|
|
||||||
|
The spherical harmonic of degree `n` and order `m` can be written as
|
||||||
|
:math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
|
||||||
|
where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
|
||||||
|
{4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and
|
||||||
|
:math:\theta` are the colatitude and longitude, repectively. :math:`N_n^m` is
|
||||||
|
chosen in the way that the spherical harmonics form a set of orthonormal basis
|
||||||
|
functions of :math:`L^2(S^2)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
m: The order of the harmonic; must have `|m| <= n`. Return values for
|
||||||
|
`|m| > n` ara undefined.
|
||||||
|
n: The degree of the harmonic; must have `n >= 0`. The standard notation for
|
||||||
|
degree in descriptions of spherical harmonics is `l (lower case L)`. We
|
||||||
|
use `n` here to be consistent with `scipy.special.sph_harm`. Return
|
||||||
|
values for `n < 0` are undefined.
|
||||||
|
theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
|
||||||
|
phi: The polar (colatitudinal) coordinate; must be in [0, pi].
|
||||||
|
n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
|
||||||
|
maximum value of `n`, the results are clipped to `n_max`. For example,
|
||||||
|
`sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
|
||||||
|
acutually returns
|
||||||
|
`sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`
|
||||||
|
Returns:
|
||||||
|
A 1D array containing the spherical harmonics at (m, n, theta, phi).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if jnp.isscalar(phi):
|
||||||
|
phi = jnp.array([phi])
|
||||||
|
|
||||||
|
if n_max is None:
|
||||||
|
n_max = jnp.max(n)
|
||||||
|
n_max = core.concrete_or_error(
|
||||||
|
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
|
||||||
|
'be statically specified to use `sph_harm` within JAX transformations.')
|
||||||
|
|
||||||
|
return _sph_harm(m, n, theta, phi, n_max)
|
||||||
|
@ -39,6 +39,7 @@ from jax._src.scipy.special import (
|
|||||||
ndtr,
|
ndtr,
|
||||||
ndtri,
|
ndtri,
|
||||||
polygamma,
|
polygamma,
|
||||||
|
sph_harm,
|
||||||
xlogy,
|
xlogy,
|
||||||
xlog1py,
|
xlog1py,
|
||||||
zeta,
|
zeta,
|
||||||
|
@ -26,6 +26,7 @@ import numpy as np
|
|||||||
import scipy.special as osp_special
|
import scipy.special as osp_special
|
||||||
|
|
||||||
from jax._src import api
|
from jax._src import api
|
||||||
|
from jax import numpy as jnp
|
||||||
from jax import test_util as jtu
|
from jax import test_util as jtu
|
||||||
from jax.scipy import special as lsp_special
|
from jax.scipy import special as lsp_special
|
||||||
|
|
||||||
@ -312,6 +313,98 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
|||||||
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
|
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
|
||||||
self._CompileAndCheck(lsp_special_fn, args_maker)
|
self._CompileAndCheck(lsp_special_fn, args_maker)
|
||||||
|
|
||||||
|
def testSphHarmAccuracy(self):
|
||||||
|
m = jnp.arange(-3, 3)[:, None]
|
||||||
|
n = jnp.arange(3, 6)
|
||||||
|
n_max = 5
|
||||||
|
theta = 0.0
|
||||||
|
phi = jnp.pi
|
||||||
|
|
||||||
|
expected = lsp_special.sph_harm(m, n, theta, phi, n_max)
|
||||||
|
|
||||||
|
actual = osp_special.sph_harm(m, n, theta, phi)
|
||||||
|
|
||||||
|
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||||
|
|
||||||
|
def testSphHarmOrderZeroDegreeZero(self):
|
||||||
|
"""Tests the spherical harmonics of order zero and degree zero."""
|
||||||
|
theta = jnp.array([0.3])
|
||||||
|
phi = jnp.array([2.3])
|
||||||
|
n_max = 0
|
||||||
|
|
||||||
|
expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
|
||||||
|
actual = jnp.real(
|
||||||
|
lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max))
|
||||||
|
|
||||||
|
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
|
||||||
|
|
||||||
|
def testSphHarmOrderZeroDegreeOne(self):
|
||||||
|
"""Tests the spherical harmonics of order one and degree zero."""
|
||||||
|
theta = jnp.array([2.0])
|
||||||
|
phi = jnp.array([3.1])
|
||||||
|
n_max = 1
|
||||||
|
|
||||||
|
expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
|
||||||
|
actual = jnp.real(
|
||||||
|
lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max))
|
||||||
|
|
||||||
|
self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8)
|
||||||
|
|
||||||
|
def testSphHarmOrderOneDegreeOne(self):
|
||||||
|
"""Tests the spherical harmonics of order one and degree one."""
|
||||||
|
theta = jnp.array([2.0])
|
||||||
|
phi = jnp.array([2.5])
|
||||||
|
n_max = 1
|
||||||
|
|
||||||
|
expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) *
|
||||||
|
jnp.sin(phi) * jnp.exp(1j * theta))
|
||||||
|
actual = lsp_special.sph_harm(
|
||||||
|
jnp.array([1]), jnp.array([1]), theta, phi, n_max)
|
||||||
|
|
||||||
|
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
|
||||||
|
|
||||||
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
|
{'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
|
||||||
|
l_max, num_z, dtype),
|
||||||
|
'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
|
||||||
|
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
|
||||||
|
for dtype in jtu.dtypes.all_integer))
|
||||||
|
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
|
||||||
|
"""Tests against JIT compatibility and Numpy."""
|
||||||
|
n_max = l_max
|
||||||
|
shape = (num_z,)
|
||||||
|
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
|
||||||
|
|
||||||
|
lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)
|
||||||
|
|
||||||
|
def args_maker():
|
||||||
|
m = rng(shape, dtype)
|
||||||
|
n = abs(m)
|
||||||
|
theta = jnp.linspace(-4.0, 5.0, num_z)
|
||||||
|
phi = jnp.linspace(-2.0, 1.0, num_z)
|
||||||
|
return m, n, theta, phi
|
||||||
|
|
||||||
|
with self.subTest('Test JIT compatibility'):
|
||||||
|
self._CompileAndCheck(lsp_special_fn, args_maker)
|
||||||
|
|
||||||
|
with self.subTest('Test against numpy.'):
|
||||||
|
self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
|
||||||
|
|
||||||
|
def testSphHarmCornerCaseWithWrongNmax(self):
|
||||||
|
"""Tests the corner case where `n_max` is not the maximum value of `n`."""
|
||||||
|
m = jnp.array([2])
|
||||||
|
n = jnp.array([10])
|
||||||
|
n_clipped = jnp.array([6])
|
||||||
|
n_max = 6
|
||||||
|
theta = jnp.array([0.9])
|
||||||
|
phi = jnp.array([0.2])
|
||||||
|
|
||||||
|
expected = lsp_special.sph_harm(m, n, theta, phi, n_max)
|
||||||
|
|
||||||
|
actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)
|
||||||
|
|
||||||
|
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user