mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #25753 from dfm:shp-harm-y
PiperOrigin-RevId: 713717495
This commit is contained in:
commit
6c8b02df01
@ -135,3 +135,4 @@ register('jax-numpy-quantile-interpolation')
|
||||
register('jax-numpy-reduction-non-boolean-where')
|
||||
register('jax-numpy-trimzeros-not-1d-array')
|
||||
register('pallas-gpu-triton')
|
||||
register('jax-scipy-special-sph-harm')
|
||||
|
@ -28,6 +28,7 @@ from jax import lax
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
|
||||
@ -1731,19 +1732,19 @@ def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array:
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(4,))
|
||||
def _sph_harm(m: Array,
|
||||
n: Array,
|
||||
def _sph_harm(n: Array,
|
||||
m: Array,
|
||||
theta: Array,
|
||||
phi: Array,
|
||||
n_max: int) -> Array:
|
||||
"""Computes the spherical harmonics."""
|
||||
|
||||
cos_colatitude = jnp.cos(phi)
|
||||
cos_colatitude = jnp.cos(theta)
|
||||
|
||||
legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
|
||||
legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
|
||||
|
||||
angle = abs(m) * theta
|
||||
angle = abs(m) * phi
|
||||
vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
|
||||
harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
|
||||
legendre_val * jnp.imag(vandermonde))
|
||||
@ -1756,6 +1757,58 @@ def _sph_harm(m: Array,
|
||||
return harmonics
|
||||
|
||||
|
||||
def sph_harm_y(n: Array,
|
||||
m: Array,
|
||||
theta: Array,
|
||||
phi: Array,
|
||||
diff_n: int | None = None,
|
||||
n_max: int | None = None) -> Array:
|
||||
r"""Computes the spherical harmonics.
|
||||
|
||||
The JAX version has one extra argument `n_max`, the maximum value in `n`.
|
||||
|
||||
The spherical harmonic of degree `n` and order `m` can be written as
|
||||
:math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \theta) * \exp(i m \phi)`,
|
||||
where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
|
||||
{4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\theta` and
|
||||
:math:`\phi` are the colatitude and longitude, respectively. :math:`N_n^m` is
|
||||
chosen in the way that the spherical harmonics form a set of orthonormal basis
|
||||
functions of :math:`L^2(S^2)`.
|
||||
|
||||
Args:
|
||||
n: The degree of the harmonic; must have `n >= 0`. The standard notation for
|
||||
degree in descriptions of spherical harmonics is `l (lower case L)`. We
|
||||
use `n` here to be consistent with `scipy.special.sph_harm_y`. Return
|
||||
values for `n < 0` are undefined.
|
||||
m: The order of the harmonic; must have `|m| <= n`. Return values for
|
||||
`|m| > n` are undefined.
|
||||
theta: The polar (colatitudinal) coordinate; must be in [0, pi].
|
||||
phi: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
|
||||
diff_n: Unsupported by JAX.
|
||||
n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true
|
||||
maximum value of `n`, the results are clipped to `n_max`. For example,
|
||||
`sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)`
|
||||
actually returns
|
||||
`sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`
|
||||
Returns:
|
||||
A 1D array containing the spherical harmonics at (m, n, theta, phi).
|
||||
"""
|
||||
if diff_n is not None:
|
||||
raise NotImplementedError(
|
||||
"The 'diff_n' argument to jax.scipy.special.sph_harm_y is not supported.")
|
||||
|
||||
if jnp.isscalar(theta):
|
||||
theta = jnp.array([theta])
|
||||
|
||||
if n_max is None:
|
||||
n_max = np.max(n)
|
||||
n_max = core.concrete_or_error(
|
||||
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
|
||||
'be statically specified to use `sph_harm` within JAX transformations.')
|
||||
|
||||
return _sph_harm(n, m, theta, phi, n_max)
|
||||
|
||||
|
||||
def sph_harm(m: Array,
|
||||
n: Array,
|
||||
theta: Array,
|
||||
@ -1763,6 +1816,11 @@ def sph_harm(m: Array,
|
||||
n_max: int | None = None) -> Array:
|
||||
r"""Computes the spherical harmonics.
|
||||
|
||||
Note:
|
||||
This function is deprecated, and :func:`~jax.scipy.special.sph_harm_y`
|
||||
should be used instead, noting that the order of ``m`` and ``n`` are
|
||||
reversed, and definitions of ``theta`` and ``phi`` are swapped.
|
||||
|
||||
The JAX version has one extra argument `n_max`, the maximum value in `n`.
|
||||
|
||||
The spherical harmonic of degree `n` and order `m` can be written as
|
||||
@ -1790,17 +1848,16 @@ def sph_harm(m: Array,
|
||||
Returns:
|
||||
A 1D array containing the spherical harmonics at (m, n, theta, phi).
|
||||
"""
|
||||
|
||||
if jnp.isscalar(phi):
|
||||
phi = jnp.array([phi])
|
||||
|
||||
if n_max is None:
|
||||
n_max = np.max(n)
|
||||
n_max = core.concrete_or_error(
|
||||
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
|
||||
'be statically specified to use `sph_harm` within JAX transformations.')
|
||||
|
||||
return _sph_harm(m, n, theta, phi, n_max)
|
||||
# Added 2025-01-06.
|
||||
# TODO(dfm): Remove after deprecation period.
|
||||
deprecations.warn(
|
||||
"jax-scipy-special-sph-harm",
|
||||
("jax.scipy.special.sph_harm is deprecated. Please use "
|
||||
"jax.scipy.special.sph_harm_y instead, noting that the order of `m` and "
|
||||
"`n` are reversed, and definitions of `theta` and `phi` are swapped."),
|
||||
stacklevel=2,
|
||||
)
|
||||
return sph_harm_y(n, m, phi, theta, n_max=n_max)
|
||||
|
||||
|
||||
# exponential integrals
|
||||
|
@ -57,6 +57,7 @@ from jax._src.scipy.special import (
|
||||
softmax as softmax,
|
||||
spence as spence,
|
||||
sph_harm as sph_harm,
|
||||
sph_harm_y as sph_harm_y,
|
||||
xlog1py as xlog1py,
|
||||
xlogy as xlogy,
|
||||
zeta as zeta,
|
||||
|
@ -384,7 +384,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
message=".*scipy.special.sph_harm.*")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmAccuracy(self):
|
||||
m = jnp.arange(-3, 3)[:, None]
|
||||
@ -400,7 +400,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
message=".*scipy.special.sph_harm.*")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmOrderZeroDegreeZero(self):
|
||||
"""Tests the spherical harmonics of order zero and degree zero."""
|
||||
@ -415,7 +415,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
message=".*scipy.special.sph_harm.*")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmOrderZeroDegreeOne(self):
|
||||
"""Tests the spherical harmonics of order one and degree zero."""
|
||||
@ -430,7 +430,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
message=".*scipy.special.sph_harm.*")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmOrderOneDegreeOne(self):
|
||||
"""Tests the spherical harmonics of order one and degree one."""
|
||||
@ -452,7 +452,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.all_integer,
|
||||
)
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
message=".*scipy.special.sph_harm.*")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
|
||||
"""Tests against JIT compatibility and Numpy."""
|
||||
@ -478,7 +478,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="`scipy.special.sph_harm` is deprecated")
|
||||
message=".*scipy.special.sph_harm.*")
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmCornerCaseWithWrongNmax(self):
|
||||
"""Tests the corner case where `n_max` is not the maximum value of `n`."""
|
||||
@ -495,6 +495,35 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(l_max=l_max, num_z=num_z)
|
||||
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
|
||||
],
|
||||
dtype=jtu.dtypes.all_integer,
|
||||
)
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testSphHarmY(self, l_max, num_z, dtype):
|
||||
if jtu.is_device_tpu(6, "e"):
|
||||
self.skipTest("TODO(b/364258243): fails on TPU v6e")
|
||||
n_max = l_max
|
||||
shape = (num_z,)
|
||||
rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)
|
||||
|
||||
def args_maker():
|
||||
m = rng(shape, dtype)
|
||||
n = abs(m)
|
||||
theta = np.linspace(-2.0, 1.0, num_z)
|
||||
phi = np.linspace(-4.0, 5.0, num_z)
|
||||
return n, m, theta, phi
|
||||
|
||||
lsp_special_fn = partial(lsp_special.sph_harm_y, n_max=n_max)
|
||||
self._CompileAndCheck(lsp_special_fn, args_maker)
|
||||
if scipy_version < (1, 15, 0):
|
||||
osp_special_fn = lambda n, m, theta, phi: osp_special.sph_harm(m, n, phi, theta)
|
||||
else:
|
||||
osp_special_fn = osp_special.sph_harm_y
|
||||
self._CheckAgainstNumpy(osp_special_fn, lsp_special_fn, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
n_zero_sv=n_zero_svs,
|
||||
degeneracy=degeneracies,
|
||||
|
Loading…
x
Reference in New Issue
Block a user