mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Support value computation of associated Legendre functions.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
This commit is contained in:
parent
1e4d28a2d9
commit
095e6507b9
@ -100,6 +100,7 @@ jax.scipy.special
|
||||
logit
|
||||
logsumexp
|
||||
lpmn
|
||||
lpmn_values
|
||||
multigammaln
|
||||
ndtr
|
||||
ndtri
|
||||
|
@ -914,7 +914,7 @@ def _gen_associated_legendre(l_max: int,
|
||||
return p
|
||||
|
||||
|
||||
def lpmn(m, n, z):
|
||||
def lpmn(m: int, n: int, z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
"""The associated Legendre functions (ALFs) of the first kind.
|
||||
|
||||
Args:
|
||||
@ -956,3 +956,57 @@ def lpmn(m, n, z):
|
||||
p_derivatives = _gen_derivatives(p_vals, z, is_normalized)
|
||||
|
||||
return (p_vals, p_derivatives)
|
||||
|
||||
|
||||
def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndarray:
|
||||
r"""The associated Legendre functions (ALFs) of the first kind.
|
||||
|
||||
Unlike `lpmn`, this function only computes the values of ALFs.
|
||||
The ALFs of the first kind can be used in spherical harmonics. The
|
||||
spherical harmonic of degree `l` and order `m` can be written as
|
||||
:math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`,
|
||||
where :math:`N_l^m` is the normalization factor and θ and φ are the
|
||||
colatitude and longitude, repectively. :math:`N_l^m` is chosen in the
|
||||
way that the spherical harmonics form a set of orthonormal basis function
|
||||
of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow
|
||||
and achieves better numerical stability.
|
||||
|
||||
Args:
|
||||
m: The maximum order of the associated Legendre functions.
|
||||
n: The maximum degree of the associated Legendre function, often called
|
||||
`l` in describing ALFs. Both the degrees and orders are
|
||||
`[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
|
||||
z: A vector of type `float32` or `float64` containing the sampling
|
||||
points at which the ALFs are computed.
|
||||
is_normalized: True if the associated Legendre functions are normalized.
|
||||
With normalization, :math:`N_l^m` is applied such that the spherical
|
||||
harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`.
|
||||
|
||||
Returns:
|
||||
A 3D array of shape `(l_max + 1, l_max + 1, len(z))` containing
|
||||
the values of the associated Legendre functions of the first kind. The
|
||||
return type matches the type of `z`.
|
||||
|
||||
Raises:
|
||||
TypeError if elements of array `z` are not in (float32, float64).
|
||||
ValueError if array `z` is not 1D.
|
||||
NotImplementedError if `m!=n`.
|
||||
"""
|
||||
dtype = lax.dtype(z)
|
||||
if dtype not in (jnp.float32, jnp.float64):
|
||||
raise TypeError(
|
||||
'z.dtype={} is not supported, see docstring for supported types.'
|
||||
.format(dtype))
|
||||
|
||||
if z.ndim != 1:
|
||||
raise ValueError('z must be a 1D array.')
|
||||
|
||||
m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
|
||||
n = core.concrete_or_error(int, n, 'Argument n of lpmn.')
|
||||
|
||||
if m != n:
|
||||
raise NotImplementedError('Computations for m!=n are not yet supported.')
|
||||
|
||||
l_max = n
|
||||
|
||||
return _gen_associated_legendre(l_max, z, is_normalized)
|
||||
|
@ -33,6 +33,7 @@ from jax._src.scipy.special import (
|
||||
logit,
|
||||
logsumexp,
|
||||
lpmn,
|
||||
lpmn_values,
|
||||
multigammaln,
|
||||
log_ndtr,
|
||||
ndtr,
|
||||
|
@ -266,6 +266,52 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(actual_p_derivatives,expected_p_derivatives,
|
||||
rtol=1e-6, atol=8.4e-4)
|
||||
|
||||
with self.subTest('Test JIT compatibility'):
|
||||
args_maker = lambda: [z]
|
||||
lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
|
||||
self._CompileAndCheck(lsp_special_fn, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_maxdegree={}_inputsize={}".format(l_max, num_z),
|
||||
"l_max": l_max,
|
||||
"num_z": num_z}
|
||||
for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
|
||||
def testNormalizedLpmnValues(self, l_max, num_z):
|
||||
# Points on which the associated Legendre functions areevaluated.
|
||||
z = np.linspace(-0.2, 0.9, num_z)
|
||||
is_normalized = True
|
||||
actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
|
||||
|
||||
# The expected results are obtained from scipy.
|
||||
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
|
||||
for i in range(num_z):
|
||||
expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]
|
||||
|
||||
def apply_normalization(a):
|
||||
"""Applies normalization to the associated Legendre functions."""
|
||||
num_m, num_l, _ = a.shape
|
||||
a_normalized = np.zeros_like(a)
|
||||
for m in range(num_m):
|
||||
for l in range(num_l):
|
||||
c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
|
||||
c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
|
||||
c2 = np.sqrt(c0 / c1)
|
||||
a_normalized[m, l] = c2 * a[m, l]
|
||||
return a_normalized
|
||||
|
||||
# The results from scipy are not normalized and the comparison requires
|
||||
# normalizing the results.
|
||||
expected_p_vals_normalized = apply_normalization(expected_p_vals)
|
||||
|
||||
with self.subTest('Test accuracy.'):
|
||||
self.assertAllClose(actual_p_vals, expected_p_vals_normalized, rtol=1e-6, atol=3.2e-6)
|
||||
|
||||
with self.subTest('Test JIT compatibility'):
|
||||
args_maker = lambda: [z]
|
||||
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
|
||||
self._CompileAndCheck(lsp_special_fn, args_maker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user