mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[linalg] Fix a bug in computing derivatives of scipy.special.lpmn.
PiperOrigin-RevId: 447807140
This commit is contained in:
parent
a62ca21b15
commit
48f47c36c4
@ -769,7 +769,7 @@ def _gen_derivatives(p: jnp.ndarray,
|
||||
if num_l > 2:
|
||||
l_vec = jnp.arange(2, num_l - 1)
|
||||
p_p2 = p[2, 2:num_l - 1, :]
|
||||
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
|
||||
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec * (l_vec - 1))
|
||||
update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
|
||||
p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)
|
||||
|
||||
|
@ -322,7 +322,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
{"testcase_name": "_{}_lmax={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), l_max),
|
||||
"l_max": l_max, "shape": shape, "dtype": dtype}
|
||||
for l_max in [1, 2, 3]
|
||||
for l_max in [1, 2, 3, 6]
|
||||
for shape in [(5,), (10,)]
|
||||
for dtype in float_dtypes))
|
||||
def testLpmn(self, l_max, shape, dtype):
|
||||
@ -336,8 +336,9 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
|
||||
return np.dstack(vals), np.dstack(derivs)
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-6, atol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5,
|
||||
atol=3e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_lmax={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user