[linalg] Fix a bug in computing derivatives of scipy.special.lpmn.

PiperOrigin-RevId: 447807140
This commit is contained in:
Tianjian Lu 2022-05-10 13:02:18 -07:00 committed by jax authors
parent a62ca21b15
commit 48f47c36c4
2 changed files with 5 additions and 4 deletions

View File

@ -769,7 +769,7 @@ def _gen_derivatives(p: jnp.ndarray,
if num_l > 2:
l_vec = jnp.arange(2, num_l - 1)
p_p2 = p[2, 2:num_l - 1, :]
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec * (l_vec - 1))
update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

View File

@ -322,7 +322,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
{"testcase_name": "_{}_lmax={}".format(
jtu.format_shape_dtype_string(shape, dtype), l_max),
"l_max": l_max, "shape": shape, "dtype": dtype}
for l_max in [1, 2, 3]
for l_max in [1, 2, 3, 6]
for shape in [(5,), (10,)]
for dtype in float_dtypes))
def testLpmn(self, l_max, shape, dtype):
@ -336,8 +336,9 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
return np.dstack(vals), np.dstack(derivs)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-6, atol=1e-6)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5,
atol=3e-3)
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_lmax={}".format(