Merge pull request #25753 from dfm:shp-harm-y

PiperOrigin-RevId: 713717495
This commit is contained in:
jax authors 2025-01-09 10:28:32 -08:00
commit 6c8b02df01
4 changed files with 109 additions and 21 deletions

View File

@ -135,3 +135,4 @@ register('jax-numpy-quantile-interpolation')
register('jax-numpy-reduction-non-boolean-where')
register('jax-numpy-trimzeros-not-1d-array')
register('pallas-gpu-triton')
register('jax-scipy-special-sph-harm')

View File

@ -28,6 +28,7 @@ from jax import lax
from jax._src import core
from jax._src import custom_derivatives
from jax._src import deprecations
from jax._src import dtypes
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
@ -1731,19 +1732,19 @@ def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array:
@partial(jit, static_argnums=(4,))
def _sph_harm(m: Array,
n: Array,
def _sph_harm(n: Array,
m: Array,
theta: Array,
phi: Array,
n_max: int) -> Array:
"""Computes the spherical harmonics."""
cos_colatitude = jnp.cos(phi)
cos_colatitude = jnp.cos(theta)
legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
angle = abs(m) * theta
angle = abs(m) * phi
vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
legendre_val * jnp.imag(vandermonde))
@ -1756,6 +1757,58 @@ def _sph_harm(m: Array,
return harmonics
def sph_harm_y(n: Array,
m: Array,
theta: Array,
phi: Array,
diff_n: int | None = None,
n_max: int | None = None) -> Array:
r"""Computes the spherical harmonics.
The JAX version has one extra argument `n_max`, the maximum value in `n`.
The spherical harmonic of degree `n` and order `m` can be written as
:math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \theta) * \exp(i m \phi)`,
where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
{4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\theta` and
:math:`\phi` are the colatitude and longitude, respectively. :math:`N_n^m` is
chosen in the way that the spherical harmonics form a set of orthonormal basis
functions of :math:`L^2(S^2)`.
Args:
n: The degree of the harmonic; must have `n >= 0`. The standard notation for
degree in descriptions of spherical harmonics is `l (lower case L)`. We
use `n` here to be consistent with `scipy.special.sph_harm_y`. Return
values for `n < 0` are undefined.
m: The order of the harmonic; must have `|m| <= n`. Return values for
`|m| > n` are undefined.
theta: The polar (colatitudinal) coordinate; must be in [0, pi].
phi: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
diff_n: Unsupported by JAX.
n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
maximum value of `n`, the results are clipped to `n_max`. For example,
`sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
actually returns
`sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`
Returns:
A 1D array containing the spherical harmonics at (m, n, theta, phi).
"""
if diff_n is not None:
raise NotImplementedError(
"The 'diff_n' argument to jax.scipy.special.sph_harm_y is not supported.")
if jnp.isscalar(theta):
theta = jnp.array([theta])
if n_max is None:
n_max = np.max(n)
n_max = core.concrete_or_error(
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
'be statically specified to use `sph_harm` within JAX transformations.')
return _sph_harm(n, m, theta, phi, n_max)
def sph_harm(m: Array,
n: Array,
theta: Array,
@ -1763,6 +1816,11 @@ def sph_harm(m: Array,
n_max: int | None = None) -> Array:
r"""Computes the spherical harmonics.
Note:
This function is deprecated, and :func:`~jax.scipy.special.sph_harm_y`
should be used instead, noting that the order of ``m`` and ``n`` are
reversed, and definitions of ``theta`` and ``phi`` are swapped.
The JAX version has one extra argument `n_max`, the maximum value in `n`.
The spherical harmonic of degree `n` and order `m` can be written as
@ -1790,17 +1848,16 @@ def sph_harm(m: Array,
Returns:
A 1D array containing the spherical harmonics at (m, n, theta, phi).
"""
if jnp.isscalar(phi):
phi = jnp.array([phi])
if n_max is None:
n_max = np.max(n)
n_max = core.concrete_or_error(
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
'be statically specified to use `sph_harm` within JAX transformations.')
return _sph_harm(m, n, theta, phi, n_max)
# Added 2025-01-06.
# TODO(dfm): Remove after deprecation period.
deprecations.warn(
"jax-scipy-special-sph-harm",
("jax.scipy.special.sph_harm is deprecated. Please use "
"jax.scipy.special.sph_harm_y instead, noting that the order of `m` and "
"`n` are reversed, and definitions of `theta` and `phi` are swapped."),
stacklevel=2,
)
return sph_harm_y(n, m, phi, theta, n_max=n_max)
# exponential integrals

View File

@ -57,6 +57,7 @@ from jax._src.scipy.special import (
softmax as softmax,
spence as spence,
sph_harm as sph_harm,
sph_harm_y as sph_harm_y,
xlog1py as xlog1py,
xlogy as xlogy,
zeta as zeta,

View File

@ -384,7 +384,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
message=".*scipy.special.sph_harm.*")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmAccuracy(self):
m = jnp.arange(-3, 3)[:, None]
@ -400,7 +400,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
message=".*scipy.special.sph_harm.*")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmOrderZeroDegreeZero(self):
"""Tests the spherical harmonics of order zero and degree zero."""
@ -415,7 +415,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
message=".*scipy.special.sph_harm.*")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmOrderZeroDegreeOne(self):
"""Tests the spherical harmonics of order one and degree zero."""
@ -430,7 +430,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
message=".*scipy.special.sph_harm.*")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmOrderOneDegreeOne(self):
"""Tests the spherical harmonics of order one and degree one."""
@ -452,7 +452,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
dtype=jtu.dtypes.all_integer,
)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
message=".*scipy.special.sph_harm.*")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
"""Tests against JIT compatibility and Numpy."""
@ -478,7 +478,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
@jtu.ignore_warning(category=DeprecationWarning,
message="`scipy.special.sph_harm` is deprecated")
message=".*scipy.special.sph_harm.*")
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmCornerCaseWithWrongNmax(self):
"""Tests the corner case where `n_max` is not the maximum value of `n`."""
@ -495,6 +495,35 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
@jtu.sample_product(
[dict(l_max=l_max, num_z=num_z)
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
],
dtype=jtu.dtypes.all_integer,
)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testSphHarmY(self, l_max, num_z, dtype):
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(b/364258243): fails on TPU v6e")
n_max = l_max
shape = (num_z,)
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
def args_maker():
m = rng(shape, dtype)
n = abs(m)
theta = np.linspace(-2.0, 1.0, num_z)
phi = np.linspace(-4.0, 5.0, num_z)
return n, m, theta, phi
lsp_special_fn = partial(lsp_special.sph_harm_y, n_max=n_max)
self._CompileAndCheck(lsp_special_fn, args_maker)
if scipy_version < (1, 15, 0):
osp_special_fn = lambda n, m, theta, phi: osp_special.sph_harm(m, n, phi, theta)
else:
osp_special_fn = osp_special.sph_harm_y
self._CheckAgainstNumpy(osp_special_fn, lsp_special_fn, args_maker)
@jtu.sample_product(
n_zero_sv=n_zero_svs,
degeneracy=degeneracies,