Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00

special.lpmn: use more canonical testing approach

commit 58cdc1b3d6 (parent 8f0bfcb4ef)
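The "canonical testing approach" named in the commit title is the idiom used throughout LaxBackedScipyTests: build an args_maker that draws random inputs, check numerical agreement against the scipy counterpart with _CheckAgainstNumpy, and check jit consistency with _CompileAndCheck, rather than hand-rolling fixed linspace inputs and per-case subTest assertions. A minimal sketch of the idiom, assuming the JAX-internal test helpers that appear in the diff below (jtu.rand_uniform, _CheckAgainstNumpy, _CompileAndCheck); gammaln is purely illustrative and the jtu import path varies across JAX versions:

    import numpy as np
    import scipy.special as osp_special
    import jax.scipy.special as lsp_special
    from jax._src import test_util as jtu  # import path is version-dependent

    class ExampleScipyTest(jtu.JaxTestCase):

      def testGammaln(self, shape=(10,), dtype=np.float32):
        # args_maker rebuilds fresh random inputs for every check below.
        rng = jtu.rand_uniform(self.rng(), low=0.1, high=10.0)
        args_maker = lambda: [rng(shape, dtype)]

        # Numerical agreement with the scipy reference implementation.
        self._CheckAgainstNumpy(osp_special.gammaln, lsp_special.gammaln,
                                args_maker, rtol=1e-6, atol=1e-6)
        # Agreement between the jit-compiled and uncompiled function.
        self._CompileAndCheck(lsp_special.gammaln, args_maker)

The diff: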
@@ -314,56 +314,46 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     self.assertAllClose(jax.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name":
-          "_maxdegree={}_inputsize={}".format(l_max, num_z),
-       "l_max": l_max,
-       "num_z": num_z}
-      for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
-  def testLpmn(self, l_max, num_z):
-    # Points on which the associated Legendre functions are evaluated.
-    z = np.linspace(-0.2, 0.9, num_z)
-    actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z)
+      {"testcase_name": "_{}_lmax={}".format(
+        jtu.format_shape_dtype_string(shape, dtype), l_max),
+       "l_max": l_max, "shape": shape, "dtype": dtype}
+      for l_max in [1, 2, 3]
+      for shape in [(5,), (10,)]
+      for dtype in float_dtypes))
+  def testLpmn(self, l_max, shape, dtype):
+    rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
+    args_maker = lambda: [rng(shape, dtype)]
 
-    # The expected results are obtained from scipy.
-    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
-    expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))
+    lax_fun = partial(lsp_special.lpmn, l_max, l_max)
 
-    for i in range(num_z):
-      val, derivative = osp_special.lpmn(l_max, l_max, z[i])
-      expected_p_vals[:, :, i] = val
-      expected_p_derivatives[:, :, i] = derivative
+    def scipy_fun(z, m=l_max, n=l_max):
+      # scipy only supports scalar inputs for z, so we must loop here.
+      vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
+      return np.dstack(vals), np.dstack(derivs)
 
-    with self.subTest('Test values.'):
-      self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6)
-
-    with self.subTest('Test derivatives.'):
-      self.assertAllClose(actual_p_derivatives, expected_p_derivatives,
-                          rtol=1e-6, atol=8.4e-4)
-
-    with self.subTest('Test JIT compatibility'):
-      args_maker = lambda: [z]
-      lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
-      self._CompileAndCheck(lsp_special_fn, args_maker)
+    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-6, atol=1e-6)
+    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name":
-          "_maxdegree={}_inputsize={}".format(l_max, num_z),
-       "l_max": l_max,
-       "num_z": num_z}
-      for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
-  def testNormalizedLpmnValues(self, l_max, num_z):
-    # Points on which the associated Legendre functions are evaluated.
-    z = np.linspace(-0.2, 0.9, num_z)
-    is_normalized = True
-    actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
+      {"testcase_name": "_{}_lmax={}".format(
+        jtu.format_shape_dtype_string(shape, dtype), l_max),
+       "l_max": l_max, "shape": shape, "dtype": dtype}
+      for l_max in [3, 4, 6, 32]
+      for shape in [(2,), (3,), (4,), (64,)]
+      for dtype in float_dtypes))
+  def testNormalizedLpmnValues(self, l_max, shape, dtype):
+    rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
+    args_maker = lambda: [rng(shape, dtype)]
 
-    # The expected results are obtained from scipy.
-    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
-    for i in range(num_z):
-      expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]
+    # Note: we test only the normalized values, not the derivatives.
+    lax_fun = partial(lsp_special.lpmn_values, l_max, l_max, is_normalized=True)
 
-    def apply_normalization(a):
-      """Applies normalization to the associated Legendre functions."""
+    def scipy_fun(z, m=l_max, n=l_max):
+      # scipy only supports scalar inputs for z, so we must loop here.
+      vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
+      a = np.dstack(vals)
+
+      # apply the normalization
       num_m, num_l, _ = a.shape
       a_normalized = np.zeros_like(a)
       for m in range(num_m):
@@ -374,17 +364,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
         a_normalized[m, l] = c2 * a[m, l]
       return a_normalized
 
-    # The results from scipy are not normalized and the comparison requires
-    # normalizing the results.
-    expected_p_vals_normalized = apply_normalization(expected_p_vals)
-
-    with self.subTest('Test accuracy.'):
-      self.assertAllClose(actual_p_vals, expected_p_vals_normalized, rtol=1e-6, atol=3.2e-6)
-
-    with self.subTest('Test JIT compatibility'):
-      args_maker = lambda: [z]
-      lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
-      self._CompileAndCheck(lsp_special_fn, args_maker)
+    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=1e-5)
+    self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
 
   def testSphHarmAccuracy(self):
     m = jnp.arange(-3, 3)[:, None]
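A note on the elided hunk: the second header (@@ -374,17 +364,8 @@) skips over the inner loop that computes the constant c2, so the normalization itself is not visible in this diff. With is_normalized=True, lsp_special.lpmn_values returns values scaled so that the corresponding spherical harmonics are normalized, which is why the raw scipy values have to be rescaled before comparison. A sketch of that rescaling, assuming the standard spherical-harmonic constant sqrt((2l+1)(l-m)! / (4*pi*(l+m)!)); the elided lines presumably compute an equivalent c2:

    import numpy as np
    from scipy.special import factorial

    def apply_normalization(a):
      """Rescales raw lpmn output a[m, l, :] by the spherical-harmonic
      normalization constant; entries with m > l stay zero."""
      num_m, num_l, _ = a.shape
      a_normalized = np.zeros_like(a)
      for m in range(num_m):
        for l in range(num_l):
          c0 = (2.0 * l + 1.0) * factorial(l - m)  # factorial(<0) == 0
          c1 = 4.0 * np.pi * factorial(l + m)
          a_normalized[m, l] = np.sqrt(c0 / c1) * a[m, l]
      return a_normalized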