From 58cdc1b3d6ca338f8e229fd8cdf37082eb0458e5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 19 Oct 2021 13:46:32 -0700 Subject: [PATCH] special.lpmn: use more canonical testing approach --- tests/lax_scipy_test.py | 89 ++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 54 deletions(-) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index b34422dc1..c4357a7e9 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -314,56 +314,46 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_maxdegree={}_inputsize={}".format(l_max, num_z), - "l_max": l_max, - "num_z": num_z} - for l_max, num_z in zip([1, 2, 3], [6, 7, 8]))) - def testLpmn(self, l_max, num_z): - # Points on which the associated Legendre functions areevaluated. - z = np.linspace(-0.2, 0.9, num_z) - actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z) + {"testcase_name": "_{}_lmax={}".format( + jtu.format_shape_dtype_string(shape, dtype), l_max), + "l_max": l_max, "shape": shape, "dtype": dtype} + for l_max in [1, 2, 3] + for shape in [(5,), (10,)] + for dtype in float_dtypes)) + def testLpmn(self, l_max, shape, dtype): + rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) + args_maker = lambda: [rng(shape, dtype)] - # The expected results are obtained from scipy. - expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z)) - expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z)) + lax_fun = partial(lsp_special.lpmn, l_max, l_max) - for i in range(num_z): - val, derivative = osp_special.lpmn(l_max, l_max, z[i]) - expected_p_vals[:, :, i] = val - expected_p_derivatives[:, :, i] = derivative + def scipy_fun(z, m=l_max, n=l_max): + # scipy only supports scalar inputs for z, so we must loop here. + vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z)) + return np.dstack(vals), np.dstack(derivs) - with self.subTest('Test values.'): - self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6) - - with self.subTest('Test derivatives.'): - self.assertAllClose(actual_p_derivatives,expected_p_derivatives, - rtol=1e-6, atol=8.4e-4) - - with self.subTest('Test JIT compatibility'): - args_maker = lambda: [z] - lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z) - self._CompileAndCheck(lsp_special_fn, args_maker) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-6, atol=1e-6) + self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_maxdegree={}_inputsize={}".format(l_max, num_z), - "l_max": l_max, - "num_z": num_z} - for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64]))) - def testNormalizedLpmnValues(self, l_max, num_z): - # Points on which the associated Legendre functions areevaluated. - z = np.linspace(-0.2, 0.9, num_z) - is_normalized = True - actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized) + {"testcase_name": "_{}_lmax={}".format( + jtu.format_shape_dtype_string(shape, dtype), l_max), + "l_max": l_max, "shape": shape, "dtype": dtype} + for l_max in [3, 4, 6, 32] + for shape in [(2,), (3,), (4,), (64,)] + for dtype in float_dtypes)) + def testNormalizedLpmnValues(self, l_max, shape, dtype): + rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) + args_maker = lambda: [rng(shape, dtype)] - # The expected results are obtained from scipy. - expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z)) - for i in range(num_z): - expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0] + # Note: we test only the normalized values, not the derivatives. + lax_fun = partial(lsp_special.lpmn_values, l_max, l_max, is_normalized=True) - def apply_normalization(a): - """Applies normalization to the associated Legendre functions.""" + def scipy_fun(z, m=l_max, n=l_max): + # scipy only supports scalar inputs for z, so we must loop here. + vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z)) + a = np.dstack(vals) + + # apply the normalization num_m, num_l, _ = a.shape a_normalized = np.zeros_like(a) for m in range(num_m): @@ -374,17 +364,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase): a_normalized[m, l] = c2 * a[m, l] return a_normalized - # The results from scipy are not normalized and the comparison requires - # normalizing the results. - expected_p_vals_normalized = apply_normalization(expected_p_vals) - - with self.subTest('Test accuracy.'): - self.assertAllClose(actual_p_vals, expected_p_vals_normalized, rtol=1e-6, atol=3.2e-6) - - with self.subTest('Test JIT compatibility'): - args_maker = lambda: [z] - lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized) - self._CompileAndCheck(lsp_special_fn, args_maker) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=1e-5) + self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6) def testSphHarmAccuracy(self): m = jnp.arange(-3, 3)[:, None]