special.lpmn: use more canonical testing approach

This commit is contained in:
Jake VanderPlas 2021-10-19 13:46:32 -07:00
parent 8f0bfcb4ef
commit 58cdc1b3d6

View File

@ -314,56 +314,46 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_maxdegree={}_inputsize={}".format(l_max, num_z),
"l_max": l_max,
"num_z": num_z}
for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
def testLpmn(self, l_max, num_z):
# Points on which the associated Legendre functions areevaluated.
z = np.linspace(-0.2, 0.9, num_z)
actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z)
{"testcase_name": "_{}_lmax={}".format(
jtu.format_shape_dtype_string(shape, dtype), l_max),
"l_max": l_max, "shape": shape, "dtype": dtype}
for l_max in [1, 2, 3]
for shape in [(5,), (10,)]
for dtype in float_dtypes))
def testLpmn(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]
# The expected results are obtained from scipy.
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))
lax_fun = partial(lsp_special.lpmn, l_max, l_max)
for i in range(num_z):
val, derivative = osp_special.lpmn(l_max, l_max, z[i])
expected_p_vals[:, :, i] = val
expected_p_derivatives[:, :, i] = derivative
def scipy_fun(z, m=l_max, n=l_max):
# scipy only supports scalar inputs for z, so we must loop here.
vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
return np.dstack(vals), np.dstack(derivs)
with self.subTest('Test values.'):
self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6)
with self.subTest('Test derivatives.'):
self.assertAllClose(actual_p_derivatives,expected_p_derivatives,
rtol=1e-6, atol=8.4e-4)
with self.subTest('Test JIT compatibility'):
args_maker = lambda: [z]
lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
self._CompileAndCheck(lsp_special_fn, args_maker)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-6, atol=1e-6)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_maxdegree={}_inputsize={}".format(l_max, num_z),
"l_max": l_max,
"num_z": num_z}
for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
def testNormalizedLpmnValues(self, l_max, num_z):
# Points on which the associated Legendre functions areevaluated.
z = np.linspace(-0.2, 0.9, num_z)
is_normalized = True
actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
{"testcase_name": "_{}_lmax={}".format(
jtu.format_shape_dtype_string(shape, dtype), l_max),
"l_max": l_max, "shape": shape, "dtype": dtype}
for l_max in [3, 4, 6, 32]
for shape in [(2,), (3,), (4,), (64,)]
for dtype in float_dtypes))
def testNormalizedLpmnValues(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]
# The expected results are obtained from scipy.
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
for i in range(num_z):
expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]
# Note: we test only the normalized values, not the derivatives.
lax_fun = partial(lsp_special.lpmn_values, l_max, l_max, is_normalized=True)
def apply_normalization(a):
"""Applies normalization to the associated Legendre functions."""
def scipy_fun(z, m=l_max, n=l_max):
# scipy only supports scalar inputs for z, so we must loop here.
vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
a = np.dstack(vals)
# apply the normalization
num_m, num_l, _ = a.shape
a_normalized = np.zeros_like(a)
for m in range(num_m):
@ -374,17 +364,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
a_normalized[m, l] = c2 * a[m, l]
return a_normalized
# The results from scipy are not normalized and the comparison requires
# normalizing the results.
expected_p_vals_normalized = apply_normalization(expected_p_vals)
with self.subTest('Test accuracy.'):
self.assertAllClose(actual_p_vals, expected_p_vals_normalized, rtol=1e-6, atol=3.2e-6)
with self.subTest('Test JIT compatibility'):
args_maker = lambda: [z]
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
self._CompileAndCheck(lsp_special_fn, args_maker)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=1e-5)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
def testSphHarmAccuracy(self):
m = jnp.arange(-3, 3)[:, None]